Commit dc7d63a6 authored by Gavin Lee
Browse files

add classification example

parent 569f2547
Pipeline #307960 passed with stage
in 4 minutes and 21 seconds
%% Cell type:code id:b8fa1017-812e-4a6b-9cf5-d69be1a718cb tags:
``` python
import numpy as np
```
%% Cell type:markdown id:e1a3af63-e023-4d31-b550-aede5d359569 tags:
# Source
Adapted from the scikit-learn digits classification example:
https://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html#sphx-glr-auto-examples-classification-plot-digits-classification-py
%% Cell type:code id:1bc291cd-0192-4611-b77e-7e941238e7fc tags:
``` python
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# License: BSD 3 clause
# Standard scientific Python imports
import matplotlib.pyplot as plt
# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics
from sklearn.model_selection import train_test_split
```
%% Cell type:code id:c521b6ba-f363-4444-a929-915b3962ec9c tags:
``` python
# Load the digits dataset that ships with scikit-learn.
digits = datasets.load_digits()

# Preview the first samples with their labels. zip() stops at the shortest
# iterable, so exactly one image is drawn per subplot axis (four in total).
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title("Training: %i" % label)
```
%%%% Output: display_data
![]()
%% Cell type:code id:95670bd9-85b5-4f46-a8bb-68b9d8a7f210 tags:
``` python
# Turn every image into a 1-D feature vector: keep the sample axis and
# collapse all remaining axes into one.
n_samples = digits.images.shape[0]
data = digits.images.reshape(n_samples, -1)

# Support vector classifier used to recognise the digits.
clf = svm.SVC(gamma=0.001)

# Hold out half of the samples for testing. shuffle=False keeps the original
# sample order, so the split is the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    data,
    digits.target,
    test_size=0.5,
    shuffle=False,
)

# Train on the first subset, then predict the digit for each held-out sample.
# fit() returns the fitted estimator, so the two calls can be chained.
predicted = clf.fit(X_train, y_train).predict(X_test)
```
%% Cell type:code id:b7e1ce4d-5b4b-43ba-97e4-f2c9e61339b4 tags:
``` python
# Show the first four test samples together with the digit the classifier
# predicted for each of them (zip() stops once the four axes are used up).
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, prediction in zip(axes, X_test, predicted):
    ax.set_axis_off()
    # Test rows are flattened feature vectors; restore the 8x8 pixel grid
    # so the sample can be rendered as an image.
    image = image.reshape(8, 8)
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title(f"Prediction: {prediction}")
```
%%%% Output: display_data
![]()
%% Cell type:code id:13fe2601-b30e-4115-b783-7aef14622d6b tags:
``` python
from sklearn.metrics import accuracy_score
```
%% Cell type:code id:5170f2c1-d181-4b17-9d9f-cbbf25edfa3d tags:
``` python
from mlsconverters import export
```
%% Cell type:code id:21b21aec-5b1c-4b5c-8a2f-13ad9ff62ed0 tags:
``` python
## MLS to schema
# Score the predictions against the held-out labels, then pass the fitted
# classifier and its (metric function, score) pair to mlsconverters' export.
# NOTE(review): presumably export() serialises the run into the Renku MLS
# schema consumed by `renku mls leaderboard` -- confirm against mlsconverters.
acc = accuracy_score(y_true=y_test, y_pred=predicted)
export(
    clf,
    evaluation_measure=(accuracy_score, acc),
)
```
%% Cell type:code id:196c78ba-3f86-429d-914e-19be240023c6 tags:
``` python
# IPython shell escape: change to the parent directory and run the Renku MLS
# leaderboard command there.
# NOTE(review): the recorded output shows this failing with "Project version
# is outdated"; running `renku migrate` first is what the message asks for.
!cd ../; renku mls leaderboard
```
%%%% Output: stream
Error: Project version is outdated and a migration is required.
Run `renku migrate` command to fix the issue.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment