Skip to content
Snippets Groups Projects
Commit e867a84d authored by Daniel Müller's avatar Daniel Müller :speech_balloon:
Browse files

Remove empty cell SVM solution

parent 4089f4c8
Branches
No related tags found
1 merge request!124Fix CNN + Algorithm Comparison
%% Cell type:markdown id:4faad863-6d6b-4c19-b9a5-67131aa6f123 tags:
# Support Vector Machine
In this exercise we will see another more different usage of the SVM. Maybe you already saw the problem of handwritten-digit-classification. Its a common problem for neural networks and we will now try to solve it by using a mathematical model.
Remember there is no right and wrong solution to this. The solution-notebook just shows you an example.
## Imports
These are the imports we will use.
%% Cell type:code id:6442efc2-de82-418c-8de3-f1ad68b506c9 tags:
``` python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import scale
from sklearn.model_selection import validation_curve
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
from sklearn import metrics
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
```
%% Cell type:markdown id:021ba92c-bf02-4001-bd8f-3b78bd7ee5dd tags:
We import the data.
%% Cell type:code id:65b6375e-abe9-4af5-9cfc-335ffb539fe1 tags:
``` python
train_data = pd.read_csv("../data/Digits/train_digits.csv.gz", nrows=20000)
test_data = pd.read_csv("../data/Digits/test_digits.csv.gz",nrows=5000)
#train_data = train_data[:20000]
#test_data = test_data[:10000]
```
%% Cell type:markdown id:703956d1-0290-411c-92d1-d6d5cb844a1f tags:
## Task 1:
Visualize the data. Pyplot and seaborn is imported. This task is only for the purpose to understand the struture of the data.
%% Cell type:code id:b70150bb-d9b3-443b-ac30-d739b1f82973 tags:
``` python
# Visualizing the number of class and counts in the datasets
# The head() function is also really usefull
plt.plot(figure = (16,10))
g = sns.countplot(train_data["label"], palette = 'icefire')
plt.title('Number of digit classes')
train_data.label.astype('category').value_counts()
plt.show()
```
%% Output
%% Cell type:markdown id:e6bad466-e596-4c45-854f-7700c63c057e tags:
## Task 2: Preprocessing
At the moment the labels are within the actual data with the 'label' tag. And also we need to use scale() on the rest of the data.
You now need to prepocess the data. Also think about what the SVM needs and what we have in our data and do we actually need to do anything to it.
Finally we will use the train_text_split function to make a test data set.
%% Cell type:code id:beb2c23f-f90c-44dd-8467-a57c3a65399c tags:
``` python
# Separating the X and Y variable
y = train_data['label']
# Dropping the variable 'label' from X variable
X = train_data.drop(columns = 'label')
X_scaled = scale(X)
# train test split
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size = 0.3, train_size = 0.2 ,random_state = 10)
```
%% Cell type:markdown id:93f99db2-cfde-4763-b444-e2cfc9a25ded tags:
## Task 3: Model training
The third and final task is to actually load the model, train it and finally test its accuracy.
You already know the method from the actual notebook of the SVM.
Finally we will visualize the results. In this example it will be best to use a confusion matrix.
%% Cell type:code id:8546673f-3319-4ec0-9090-fa75a652e0ab tags:
``` python
model_linear = SVC(kernel='linear')
model_linear.fit(X_train, y_train)
# predict
y_pred = model_linear.predict(X_test)
```
%% Cell type:code id:1013637b-250d-400b-95ef-cee8a39cff15 tags:
``` python
# accuracy
print("accuracy:", metrics.accuracy_score(y_true=y_test, y_pred=y_pred), "\n")
# cm
mat = metrics.confusion_matrix(y_true=y_test, y_pred=y_pred)
plt.figure(figsize=(11,11))
sns.heatmap(mat, annot=True, cmap='Blues',fmt='d', cbar=False, square=True)
plt.xlabel('true label')
plt.ylabel('predicted label');
```
%% Output
accuracy: 0.9093333333333333
%% Cell type:code id:9cdb44b3-956c-4e52-aa09-e9e263d9f0a6 tags:
``` python
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment