{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train ML model to correct predictions of week 3-4 & 5-6\n",
"\n",
"This notebook create a Machine Learning `ML_model` to predict weeks 3-4 & 5-6 based on `S2S` weeks 3-4 & 5-6 forecasts and is compared to `CPC` observations for the [`s2s-ai-challenge`](https://s2s-ai-challenge.github.io/)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Synopsis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Method: `ML-based mean bias reduction`\n",
"\n",
"- calculate the ML-based bias from 2000-2019 deterministic ensemble mean forecast\n",
"- remove that the ML-based bias from 2020 forecast deterministic ensemble mean forecast"
]
},
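{
"cell_type": "markdown",
"metadata": {},
"source": [
"The mean bias reduction above can be sketched in plain xarray (a minimal illustration of the idea, not the ML model trained below; it assumes a raw-observation dataset `obs_2000_2019`, which this notebook does not load):\n",
"\n",
"```python\n",
"# hypothetical sketch: climatological mean bias per weekofyear\n",
"bias = (hind_2000_2019.mean('realization') - obs_2000_2019).groupby('forecast_time.weekofyear').mean()\n",
"# remove that bias from the 2020 ensemble-mean forecast\n",
"fct_2020_corrected = fct_2020.mean('realization').groupby('forecast_time.weekofyear') - bias\n",
"```"
]
},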
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data used\n",
"\n",
"type: renku datasets\n",
"\n",
"Training-input for Machine Learning model:\n",
"- hindcasts of models:\n",
" - ECMWF: `ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr`\n",
"\n",
"Forecast-input for Machine Learning model:\n",
"- real-time 2020 forecasts of models:\n",
" - ECMWF: `ecmwf_forecast-input_2020_biweekly_deterministic.zarr`\n",
"\n",
"Compare Machine Learning model forecast against against ground truth:\n",
"- `CPC` observations:\n",
" - `hindcast-like-observations_biweekly_deterministic.zarr`\n",
" - `forecast-like-observations_2020_biweekly_deterministic.zarr`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Resources used\n",
"- platform: renku\n",
"- memory: 8 GB\n",
"- processors: 2 CPU\n",
"- storage required: 10 GB"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Safeguards\n",
"\n",
"All points have to be [x] checked. If not, your submission is invalid.\n",
"\n",
"Changes to the code after submissions are not possible, as the `commit` before the `tag` will be reviewed.\n",
"(Only in exceptions and if previous effort in reproducibility can be found, it may be allowed to improve readability and reproducibility after November 1st 2021.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Safeguards to prevent [overfitting](https://en.wikipedia.org/wiki/Overfitting?wprov=sfti1) \n",
"\n",
"If the organizers suspect overfitting, your contribution can be disqualified.\n",
"\n",
" - [x] We did not use 2020 observations in training (explicit overfitting and cheating)\n",
" - [x] We did not repeatedly verify my model on 2020 observations and incrementally improved my RPSS (implicit overfitting)\n",
" - [x] We provide RPSS scores for the training period with script `print_RPS_per_year`, see in section 6.3 `predict`.\n",
" - [x] We tried our best to prevent [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)?wprov=sfti1).\n",
" - [x] We honor the `train-validate-test` [split principle](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets). This means that the hindcast data is split into `train` and `validate`, whereas `test` is withheld.\n",
" - [x] We did not use `test` explicitly in training or implicitly in incrementally adjusting parameters.\n",
" - [x] We considered [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics))."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Safeguards for Reproducibility\n",
"Notebook/code must be independently reproducible from scratch by the organizers (after the competition), if not possible: no prize\n",
" - [x] All training data is publicly available (no pre-trained private neural networks, as they are not reproducible for us)\n",
" - [x] Code is well documented, readable and reproducible.\n",
" - [x] Code to reproduce training and predictions is preferred to run within a day on the described architecture. If the training takes longer than a day, please justify why this is needed. Please do not submit training piplelines, which take weeks to train."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Todos to improve template\n",
"\n",
"This is just a demo.\n",
"\n",
"- [ ] use multiple predictor variables and two predicted variables\n",
"- [ ] for both `lead_time`s in one go\n",
"- [ ] consider seasonality, for now all `forecast_time` months are mixed\n",
"- [ ] make probabilistic predictions with `category` dim, for now works deterministic"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.8/site-packages/xarray/backends/cfgrib_.py:27: UserWarning: Failed to load cfgrib - most likely there is a problem accessing the ecCodes library. Try `import cfgrib` to get the full error message\n",
" warnings.warn(\n"
]
}
],
"source": [
"from tensorflow.keras.layers import Input, Dense, Flatten\n",
"from tensorflow.keras.models import Sequential\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import xarray as xr\n",
"xr.set_options(display_style='text')\n",
"import numpy as np\n",
"\n",
"from dask.utils import format_bytes\n",
"import xskillscore as xs\n",
"%load_ext tensorboard"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Get training data\n",
"\n",
"preprocessing of input data may be done in separate notebook/script"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hindcast\n",
"\n",
"get weekly initialized hindcasts"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"v='t2m'"
]
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
"\u001b[0m\n"
]
}
],
"source": [
"# preprocessed as renku dataset\n",
"!renku storage pull ../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr"
]
},
{
"cell_type": "code",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.8/site-packages/xarray/backends/plugins.py:61: RuntimeWarning: Engine 'cfgrib' loading failed:\n",
"/opt/conda/lib/python3.8/site-packages/gribapi/_bindings.cpython-38-x86_64-linux-gnu.so: undefined symbol: codes_bufr_key_is_header\n",
" warnings.warn(f\"Engine {name!r} loading failed:\\n{ex}\", RuntimeWarning)\n"
]
}
],
"hind_2000_2019 = xr.open_zarr(\"../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr\", consolidated=True)"
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Forecast"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
"\u001b[0m\n"
]
}
],
"source": [
"# preprocessed as renku dataset\n",
"!renku storage pull ../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"fct_2020 = xr.open_zarr(\"../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr\", consolidated=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Observations\n",
"categorized in terciles corresponding to hindcasts/forecasts"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
"\u001b[0m\n"
]
}
],
"!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr"
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Frozen(SortedKeysDict({'category': 3, 'forecast_time': 1060, 'latitude': 121, 'lead_time': 2, 'longitude': 240}))"
]
},
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"obs_2000_2019_p = xr.open_dataset(f'../data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr', engine='zarr')\n",
"\n",
"obs_2000_2019_p.sizes"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"obs_2000_2019_p2 = obs_2000_2019_p.assign_coords(category=[0,1,2])\n",
"obs_2000_2019_p2 = (obs_2000_2019_p2 * obs_2000_2019_p2.category).sum('category', skipna=False).compute()\n",
"#obs_2000_2019_p2.isel(forecast_time=[2,4]).t2m.plot(col='lead_time', row='forecast_time')"
"obs_2020_p = xr.open_dataset(f'../data/forecast-like-observations_2020_biweekly_terciled.nc')\n",
"obs_2020_p2 = obs_2020_p.assign_coords(category=[0,1,2])\n",
"obs_2020_p2 = (obs_2020_p2 * obs_2020_p2.category).sum('category', skipna=False).compute()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ML model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"based on [Weatherbench](https://github.com/pangeo-data/WeatherBench/blob/master/quickstart.ipynb)"
]
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fatal: destination path 'WeatherBench' already exists and is not an empty directory.\n"
"source": [
"# run once only and dont commit\n",
"!git clone https://github.com/pangeo-data/WeatherBench/"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(1, 'WeatherBench')\n",
"from WeatherBench.src.train_nn import DataGenerator, PeriodicConv2D, create_predictions\n",
"import tensorflow.keras as keras"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"bs=32\n",
"\n",
"import numpy as np\n",
"class DataGenerator(keras.utils.Sequence):\n",
" def __init__(self, fct, verif, lead_time, batch_size=bs, shuffle=True, load=True):\n",
" \"\"\"\n",
" Data generator for WeatherBench data.\n",
" Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly\n",
"\n",
" Args:\n",
" fct: forecasts from S2S models: xr.DataArray (xr.Dataset doesnt work properly)\n",
" verif: observations with same dimensionality (xr.Dataset doesnt work properly)\n",
" lead_time: Lead_time as in model\n",
" batch_size: Batch size\n",
" shuffle: bool. If True, data is shuffled.\n",
" load: bool. If True, datadet is loaded into RAM.\n",
" \n",
" Todo:\n",
" - use number in a better way, now uses only ensemble mean forecast\n",
" - dont use .sel(lead_time=lead_time) to train over all lead_time at once\n",
" - be sensitive with forecast_time, pool a few around the weekofyear given\n",
" - use more variables as predictors\n",
" - predict more variables\n",
" \"\"\"\n",
"\n",
" if isinstance(fct, xr.Dataset):\n",
" print('convert fct to array')\n",
" fct = fct.to_array().transpose(...,'variable')\n",
" self.fct_dataset=True\n",
" else:\n",
" self.fct_dataset=False\n",
" \n",
" if isinstance(verif, xr.Dataset):\n",
" print('convert verif to array')\n",
" verif = verif.to_array().transpose(...,'variable')\n",
" self.verif_dataset=True\n",
" else:\n",
" self.verif_dataset=False\n",
" \n",
" #self.fct = fct\n",
" self.batch_size = batch_size\n",
" self.shuffle = shuffle\n",
" self.lead_time = lead_time\n",
"\n",
" self.fct_data = fct.transpose('forecast_time', ...).sel(lead_time=lead_time)\n",
" self.fct_mean = self.fct_data.mean('forecast_time').compute()\n",
" self.fct_std = self.fct_data.std('forecast_time').compute()\n",
" \n",
" self.verif_data = verif.transpose('forecast_time', ...).sel(lead_time=lead_time)\n",
" #self.verif_mean = self.verif_data.mean('forecast_time').compute() if mean is None else mean\n",
" #self.verif_std = self.verif_data.std('forecast_time').compute() if std is None else std\n",
" self.fct_data = (self.fct_data - self.fct_mean) / self.fct_std\n",
" #self.verif_data = (self.verif_data - self.verif_mean) / self.verif_std\n",
" #self.verif_data = self.verif_data.astype('int32')\n",
" \n",
" self.n_samples = self.fct_data.forecast_time.size\n",
" self.forecast_time = self.fct_data.forecast_time\n",
"\n",
" self.on_epoch_end()\n",
"\n",
" # For some weird reason calling .load() earlier messes up the mean and std computations\n",
" if load:\n",
" # print('Loading data into RAM')\n",
" self.fct_data.load()\n",
"\n",
" def __len__(self):\n",
" 'Denotes the number of batches per epoch'\n",
" return int(np.ceil(self.n_samples / self.batch_size))\n",
"\n",
" def __getitem__(self, i):\n",
" 'Generate one batch of data'\n",
" idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]\n",
" # got all nan if nans not masked\n",
" X = self.fct_data.isel(forecast_time=idxs).values\n",
" y = self.verif_data.isel(forecast_time=idxs).values\n",
" return X, y\n",
"\n",
" def on_epoch_end(self):\n",
" 'Updates indexes after each epoch'\n",
" self.idxs = np.arange(self.n_samples)\n",
" if self.shuffle == True:\n",
" np.random.shuffle(self.idxs)"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre><xarray.DataArray 'lead_time' ()>\n",
"array(1209600000000000, dtype='timedelta64[ns]')\n",
"Coordinates:\n",
" lead_time timedelta64[ns] 14 days\n",
"Attributes:\n",
" comment: lead_time describes bi-weekly aggregates. The pd.Timedelta corr...</pre>"
],
"text/plain": [
"<xarray.DataArray 'lead_time' ()>\n",
"array(1209600000000000, dtype='timedelta64[ns]')\n",
"Coordinates:\n",
" lead_time timedelta64[ns] 14 days\n",
"Attributes:\n",
" comment: lead_time describes bi-weekly aggregates. The pd.Timedelta corr..."
]
},
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 2 bi-weekly `lead_time`: week 3-4\n",
"lead = hind_2000_2019.isel(lead_time=0).lead_time\n",
"\n",
"lead"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"mask = obs_2000_2019_p2.isel(forecast_time=0, lead_time=0,drop=True).notnull()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## data prep: train, valid, test\n",
"\n",
"[Use the hindcast period to split train and valid.](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets) Do not use the 2020 data for testing!"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# time is the forecast_time\n",
"time_train_start,time_train_end='2000','2017' # train\n",
"time_valid_start,time_valid_end='2018','2019' # valid\n",
"time_test = '2020' # test"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# subset forecast_time, e.g. to a given month\n",
"tattr='month'\n",
"tattr_label=1\n",
"# alternatively: tattr='season', tattr_label='DJF'\n",
"hind_2000_2019 = hind_2000_2019.sel(forecast_time=getattr(hind_2000_2019.forecast_time.dt, tattr)==tattr_label)\n",
"obs_2000_2019_p2 = obs_2000_2019_p2.sel(forecast_time=getattr(obs_2000_2019_p2.forecast_time.dt, tattr)==tattr_label)"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"dg_train = DataGenerator(\n",
" hind_2000_2019.mean('realization').sel(forecast_time=slice(time_train_start,time_train_end))[v],\n",
" obs_2000_2019_p2.sel(forecast_time=slice(time_train_start,time_train_end))[v],\n",
" lead_time=lead, batch_size=bs, shuffle=True, load=True)"
"source": [
"dg_valid = DataGenerator(\n",
" hind_2000_2019.mean('realization').sel(forecast_time=slice(time_valid_start,time_valid_end))[v],\n",
" obs_2000_2019_p2.sel(forecast_time=slice(time_valid_start,time_valid_end))[v],\n",
" lead_time=lead, batch_size=bs, shuffle=False, load=True)"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# do not use, delete?\n",
"dg_test = DataGenerator(\n",
" fct_2020.mean('realization').sel(forecast_time=time_test)[v],\n",
" obs_2020_p2.sel(forecast_time=time_test)[v],\n",
" lead_time=lead, batch_size=bs, shuffle=False, load=True)"
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X, y = dg_train[0]\n",
"X.shape, y.shape"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X, y = dg_valid[0]\n",
"X.shape, y.shape"
]
},
{
"cell_type": "code",
"# any problem from normalizing?\n",
"i=4\n",
"#xr.DataArray(np.vstack([X[i],y[i]])).plot(yincrease=False, robust=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## `fit`"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Convolution2D, MaxPooling2D, Flatten, Dense, Dropout"
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"periodic_conv2d_17 (Periodic (None, 121, 240, 64) 5248 \n",
"_________________________________________________________________\n",
"dropout_11 (Dropout) (None, 121, 240, 64) 0 \n",
"_________________________________________________________________\n",
"periodic_conv2d_18 (Periodic (None, 121, 240, 16) 25616 \n",
"_________________________________________________________________\n",
"_________________________________________________________________\n",
"=================================================================\n",
"Total params: 31,027\n",
"Trainable params: 31,027\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
" PeriodicConv2D(filters=64, kernel_size=9, conv_kwargs={'activation':'elu'},\n",
" input_shape=(121, 240, 1)),\n",
" #Dropout(dr),\n",
" #PeriodicConv2D(filters=16, kernel_size=9, conv_kwargs={'activation':'relu'}),\n",
" Dropout(dr),\n",
"cnn.summary()"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"optimizers = ['adam','RMSprop', 'sgd', keras.optimizers.Adam(learning_rate=0.01)]\n",
"losses = ['sparse_categorical_crossentropy', 'categorical_crossentropy']\n",
"\n",
"metrics=[keras.metrics.CategoricalAccuracy(),'accuracy']\n",
"\n",
"#cnn.compile(optimizer=optimizers[0], loss=losses[0], metrics=metrics[-1])"
]
},
{
"cell_type": "code",
"# https://stackoverflow.com/questions/34875944/how-to-write-a-custom-loss-function-in-tensorflow#37573411\n",
"\n",
"import tensorflow as tf\n",
"\n",
"def rps_loss(y_true, y_pred, sample_weight=None):\n",
" y_true = tf.one_hot(tf.cast(y_true,'int32'), depth=3)\n",
" Fc = tf.cumsum(y_pred, axis=-1)\n",
" Oc = tf.cumsum(y_true, axis=-1)\n",
" diff = tf.subtract(Fc, Oc)\n",
" rps = tf.reduce_sum(tf.square(diff), axis=-1)\n",
" # mask ocean\n",
" y_one = tf.reduce_sum(y_true, axis=-1)\n",
" rps = tf.multiply(rps, y_one)\n",
" # spatial mean # todo add weights\n",
" return rps\n",
"\n",
"\n",
"#xr.DataArray(tf.reduce_mean(rps_loss(obs_2020_p2.isel(lead_time=1)[v], preds_test[v].isel(lead_time=1)),axis=0),dims=['latitude','longitude'],coords={'latitude':mask.latitude,'longitude':mask.longitude}).plot(yincrease=True)"
"metadata": {},
"outputs": [],
"source": [
"cnn.compile(optimizer='adam', loss=rps_loss, metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")"
]
},
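{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, the ranked probability score implemented in `rps_loss` above can be reproduced for a single grid point with plain numpy (`rps_np` is a hypothetical helper for illustration only, not part of this notebook):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def rps_np(y_true_cat, y_prob):\n",
"    # y_true_cat: observed category (0, 1 or 2); y_prob: forecast probabilities over 3 categories\n",
"    oc = np.cumsum(np.eye(3)[y_true_cat])  # cumulative observation\n",
"    fc = np.cumsum(y_prob)                 # cumulative forecast\n",
"    return np.sum((fc - oc) ** 2)\n",
"\n",
"rps_np(1, np.array([0., 1., 0.]))  # perfect forecast -> 0.0\n",
"```"
]
},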
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"tensorboard_callback = keras.callbacks.TensorBoard(log_dir=\"./logs\")"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"6/6 [==============================] - ETA: 0s - loss: 0.0964 - accuracy: 0.0983WARNING:tensorflow:5 out of the last 16 calls to <function Model.make_test_function.<locals>.test_function at 0x7fd18e625940> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
"6/6 [==============================] - 11s 1s/step - loss: 0.0965 - accuracy: 0.0980 - val_loss: 0.1008 - val_accuracy: 0.0815\n",
"Epoch 2/3\n",
"6/6 [==============================] - 8s 1s/step - loss: 0.0960 - accuracy: 0.0954 - val_loss: 0.1004 - val_accuracy: 0.0843\n",
"Epoch 3/3\n",
"6/6 [==============================] - 9s 1s/step - loss: 0.0943 - accuracy: 0.0982 - val_loss: 0.1015 - val_accuracy: 0.0842\n"
"history = cnn.fit(dg_train, epochs=3, validation_data=dg_valid, callbacks=[tensorboard_callback])"
]
},
{
"cell_type": "code",
"execution_count": 232,
"metadata": {},
"outputs": [],
"source": [
"#history.history['loss']"
]
},
{
"cell_type": "code",
"execution_count": 233,
"metadata": {},
"outputs": [],
"source": [
"#!jupyter labextension install jupyterlab_tensorboard"
"metadata": {},
"outputs": [],
"source": [
"#%tensorboard --logdir logs/fit"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## `predict`\n",
"\n",
"Create predictions and print `mean(variable, lead_time, longitude, weighted latitude)` RPSS for all years as calculated by `skill_by_year`."
"execution_count": 235,
"metadata": {},
"outputs": [],
"source": [
"# idea: build image classifier: show RPS on valid after each epoch\n",
"# https://www.tensorflow.org/tensorboard/image_summaries"
]
},
{
"cell_type": "code",
"execution_count": 255,
"metadata": {},
"outputs": [],
"source": [
"def _create_predictions(model, dg, lead):\n",
" preds = model.predict_proba(dg)\n",
" da = xr.DataArray(preds,\n",
" dims=['forecast_time', 'latitude', 'longitude','category'],\n",
" coords={'forecast_time': dg.fct_data.forecast_time, 'latitude': dg.fct_data.latitude,\n",
" 'longitude': dg.fct_data.longitude, 'category': ['below normal','near normal','above normal']},\n",
" )\n",
" da = da.where(mask[v]) # is that needed?\n",
" da = da.assign_coords(lead_time=lead)\n",
" return da"
]
},
{
"cell_type": "code",
"source": [
"#obs_2000_2019_p[v].isel(forecast_time=[2,20,25,-3,-2,-1]).sel(lead_time=lead).squeeze().plot(col='category', row='forecast_time')"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"preds = _create_predictions(cnn, dg_train, lead)\n",
"#preds.isel(forecast_time=[2,20,25,-3,-2,-1]).squeeze().plot(col='category', row='forecast_time', vmin=.13,vmax=.53, cmap='RdBu')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 274,
"metadata": {},
"outputs": [],
"source": [
"from scripts import add_valid_time_from_forecast_reference_time_and_lead_time\n",
"\n",
"# this is not useful but results have expected dimensions\n",
"# actually train for each lead_time\n",
"\n",
"def create_predictions(cnn, fct, obs, time):\n",
" preds_test=[]\n",
" for lead in fct.lead_time:\n",
" dg = DataGenerator(fct.mean('realization').sel(forecast_time=time)[v],\n",
" obs.sel(forecast_time=time)[v],\n",
" lead_time=lead, batch_size=bs, shuffle=False)\n",
" preds_test.append(_create_predictions(cnn, dg, lead))\n",
" preds_test = xr.concat(preds_test, 'lead_time')\n",
" preds_test['lead_time'] = fct.lead_time\n",
" # add valid_time coord\n",
" preds_test = add_valid_time_from_forecast_reference_time_and_lead_time(preds_test)\n",
" preds_test = preds_test.to_dataset(name=v)\n",
" # add fake var\n",
" preds_test['tp'] = preds_test['t2m']\n",
" return preds_test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### `predict` training period in-sample"
]
},
{
"cell_type": "code",
"import os\n",
"if os.environ['HOME'] == '/home/jovyan':\n",
" import pandas as pd\n",
" # assume on renku with small memory\n",
" step = 2\n",
" skill_list = []\n",
" for year in np.arange(int(time_train_start), int(time_train_end) -1, step): # loop over years to consume less memory on renku\n",
" preds_is = create_predictions(cnn, hind_2000_2019, obs_2000_2019_p2, time=slice(str(year), str(year+step-1))).compute()\n",
" skill_list.append(skill_by_year(preds_is))\n",
" skill = pd.concat(skill_list)\n",
"else: # with larger memory, simply do\n",
" preds_is = create_predictions(cnn, hind_2000_2019, obs_2000_2019_p2, time=slice(time_train_start, time_train_end))\n",
" skill = skill_by_year(preds_is)\n",
"skill"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### `predict` validation period out-of-sample"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" </tr>\n",
" <tr>\n",
" <th>year</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2018</th>\n",
" </tr>\n",
" <tr>\n",
" <th>2019</th>\n",