
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/mixture/plot_gmm_selection.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_mixture_plot_gmm_selection.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_mixture_plot_gmm_selection.py:


================================
Gaussian Mixture Model Selection
================================

This example shows that model selection can be performed with Gaussian Mixture
Models (GMM) using :ref:`information-theory criteria <aic_bic>`. Model selection
concerns both the covariance type and the number of components in the model.

In this case, both the Akaike Information Criterion (AIC) and the Bayesian
Information Criterion (BIC) select the correct model, but we only demonstrate
the latter because BIC is better suited to identifying the true model among a
set of candidates. Unlike Bayesian procedures, such inferences are prior-free.
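
To make the difference between the two criteria concrete, the sketch below
implements their standard textbook definitions in plain Python. The function
names and the log-likelihood values are hypothetical and chosen only for
illustration; they are not taken from the models fitted in this example.

.. code-block:: python

    import math


    def aic(log_likelihood, n_params):
        """Akaike Information Criterion: 2 * k - 2 * ln(L)."""
        return 2 * n_params - 2 * log_likelihood


    def bic(log_likelihood, n_params, n_samples):
        """Bayesian Information Criterion: k * ln(n) - 2 * ln(L)."""
        return n_params * math.log(n_samples) - 2 * log_likelihood


    # Hypothetical numbers: the richer model gains 10 units of log-likelihood
    # at the cost of 6 extra parameters, on 1000 samples.
    simple = {"log_likelihood": -2000.0, "n_params": 11}
    richer = {"log_likelihood": -1990.0, "n_params": 17}
    n = 1000

    aic_prefers_richer = aic(**richer) < aic(**simple)
    bic_prefers_richer = bic(**richer, n_samples=n) < bic(**simple, n_samples=n)
    print(aic_prefers_richer, bic_prefers_richer)  # True False

With these made-up numbers AIC favors the richer model while BIC favors the
simpler one, because the BIC penalty per parameter, `ln(n)`, exceeds the AIC
penalty of 2 as soon as there are more than about 7 samples.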

.. GENERATED FROM PYTHON SOURCE LINES 18-25

Data generation
---------------

We generate two components (each containing `n_samples` points) by randomly
sampling the standard normal distribution with `numpy.random.randn`. One
component is kept spherical but shifted and rescaled. The other is deformed
to have a more general covariance matrix.
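
The covariance matrices implied by this construction can be worked out
directly: multiplying standard normal rows by a matrix `C` gives covariance
`C.T @ C`, and scaling them by 0.7 gives `0.7**2` times the identity. The
sketch below checks this numerically. The large sample size and the use of
`numpy.random.default_rng` are assumptions made for this check only and
differ from the data generation used in the example.

.. code-block:: python

    import numpy as np

    C = np.array([[0.0, -0.1], [1.7, 0.4]])

    # Rows z ~ N(0, I) mapped to z @ C have covariance C.T @ C.
    cov_general = C.T @ C
    # Scaling standard normal samples by 0.7 gives covariance 0.7**2 * I.
    cov_spherical = 0.7**2 * np.eye(2)

    # Empirical check on a large sample drawn only for this sketch.
    rng = np.random.default_rng(0)
    Z = rng.standard_normal((200_000, 2))
    empirical_general = np.cov(Z @ C, rowvar=False)
    empirical_spherical = np.cov(0.7 * Z + np.array([-4, 1]), rowvar=False)

    print(cov_general)
    print(np.round(empirical_general, 2))

The general component therefore has correlated features, which only the
`"full"` and `"tied"` covariance types can represent.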

.. GENERATED FROM PYTHON SOURCE LINES 25-36

.. code-block:: default


    import numpy as np

    n_samples = 500
    np.random.seed(0)
    C = np.array([[0.0, -0.1], [1.7, 0.4]])
    component_1 = np.dot(np.random.randn(n_samples, 2), C)  # general
    component_2 = 0.7 * np.random.randn(n_samples, 2) + np.array([-4, 1])  # spherical

    X = np.concatenate([component_1, component_2])








.. GENERATED FROM PYTHON SOURCE LINES 37-38

We can visualize the different components:

.. GENERATED FROM PYTHON SOURCE LINES 38-47

.. code-block:: default


    import matplotlib.pyplot as plt

    plt.scatter(component_1[:, 0], component_1[:, 1], s=0.8)
    plt.scatter(component_2[:, 0], component_2[:, 1], s=0.8)
    plt.title("Gaussian Mixture components")
    plt.axis("equal")
    plt.show()




.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_001.png
   :alt: Gaussian Mixture components
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 48-67

Model training and selection
----------------------------

We vary the number of components from 1 to 6 and the type of covariance
parameters to use:

- `"full"`: each component has its own general covariance matrix.
- `"tied"`: all components share the same general covariance matrix.
- `"diag"`: each component has its own diagonal covariance matrix.
- `"spherical"`: each component has its own single variance.
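
The covariance type determines how many free parameters the BIC penalizes.
As a rough sketch, the count can be written as follows. The helper
`n_free_parameters` is hypothetical and written for illustration; it is not
part of the scikit-learn public API.

.. code-block:: python

    def n_free_parameters(n_components, n_features, covariance_type):
        """Count the free parameters of a Gaussian mixture (illustrative)."""
        # A symmetric matrix has n_features * (n_features + 1) / 2 free entries.
        per_full_matrix = n_features * (n_features + 1) // 2
        cov_params = {
            "full": n_components * per_full_matrix,
            "tied": per_full_matrix,
            "diag": n_components * n_features,
            "spherical": n_components,
        }[covariance_type]
        mean_params = n_components * n_features
        weight_params = n_components - 1  # the weights sum to one
        return cov_params + mean_params + weight_params


    for cov_type in ["spherical", "tied", "diag", "full"]:
        print(cov_type, n_free_parameters(2, 2, cov_type))
    # spherical 7
    # tied 8
    # diag 9
    # full 11

For two components in two dimensions, the `"full"` model is the most flexible
and pays the largest penalty, so it is only selected when the gain in
likelihood justifies the extra parameters.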

We score the different models and keep the best model (the lowest BIC). This
is done by using :class:`~sklearn.model_selection.GridSearchCV` and a
user-defined score function which returns the negative BIC score, as
:class:`~sklearn.model_selection.GridSearchCV` is designed to **maximize** a
score (maximizing the negative BIC is equivalent to minimizing the BIC).

The best set of parameters and the best estimator are stored in `best_params_`
and `best_estimator_`, respectively.

.. GENERATED FROM PYTHON SOURCE LINES 67-87

.. code-block:: default


    from sklearn.mixture import GaussianMixture
    from sklearn.model_selection import GridSearchCV


    def gmm_bic_score(estimator, X):
        """Callable to pass to GridSearchCV that will use the BIC score."""
        # Make it negative since GridSearchCV expects a score to maximize
        return -estimator.bic(X)


    param_grid = {
        "n_components": range(1, 7),
        "covariance_type": ["spherical", "tied", "diag", "full"],
    }
    grid_search = GridSearchCV(
        GaussianMixture(), param_grid=param_grid, scoring=gmm_bic_score
    )
    grid_search.fit(X)






.. raw:: html

    <div class="output_subarea output_html rendered_html output_result">
    <style>#sk-container-id-18 {color: black;background-color: white;}#sk-container-id-18 pre{padding: 0;}#sk-container-id-18 div.sk-toggleable {background-color: white;}#sk-container-id-18 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-18 label.sk-toggleable__label-arrow:before {content: "▸";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-18 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-18 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-18 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-18 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-18 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-18 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: "▾";}#sk-container-id-18 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-18 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-18 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-18 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-18 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-18 div.sk-parallel-item::after {content: "";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-18 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-18 div.sk-serial::before {content: "";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-18 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-18 div.sk-item {position: relative;z-index: 1;}#sk-container-id-18 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-18 div.sk-item::before, #sk-container-id-18 div.sk-parallel-item::before {content: "";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-18 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-18 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-18 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-18 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-18 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-18 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-18 div.sk-label-container {text-align: center;}#sk-container-id-18 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-18 div.sk-text-repr-fallback {display: none;}</style><div id="sk-container-id-18" class="sk-top-container"><div class="sk-text-repr-fallback"><pre>GridSearchCV(estimator=GaussianMixture(),
                 param_grid={&#x27;covariance_type&#x27;: [&#x27;spherical&#x27;, &#x27;tied&#x27;, &#x27;diag&#x27;,
                                                 &#x27;full&#x27;],
                             &#x27;n_components&#x27;: range(1, 7)},
                 scoring=&lt;function gmm_bic_score at 0x7f6beca51440&gt;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class="sk-container" hidden><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="sk-estimator-id-47" type="checkbox" ><label for="sk-estimator-id-47" class="sk-toggleable__label sk-toggleable__label-arrow">GridSearchCV</label><div class="sk-toggleable__content"><pre>GridSearchCV(estimator=GaussianMixture(),
                 param_grid={&#x27;covariance_type&#x27;: [&#x27;spherical&#x27;, &#x27;tied&#x27;, &#x27;diag&#x27;,
                                                 &#x27;full&#x27;],
                             &#x27;n_components&#x27;: range(1, 7)},
                 scoring=&lt;function gmm_bic_score at 0x7f6beca51440&gt;)</pre></div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="sk-estimator-id-48" type="checkbox" ><label for="sk-estimator-id-48" class="sk-toggleable__label sk-toggleable__label-arrow">estimator: GaussianMixture</label><div class="sk-toggleable__content"><pre>GaussianMixture()</pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="sk-estimator-id-49" type="checkbox" ><label for="sk-estimator-id-49" class="sk-toggleable__label sk-toggleable__label-arrow">GaussianMixture</label><div class="sk-toggleable__content"><pre>GaussianMixture()</pre></div></div></div></div></div></div></div></div></div></div>
    </div>
    <br />
    <br />

.. GENERATED FROM PYTHON SOURCE LINES 88-94

Plot the BIC scores
-------------------

To ease the plotting, we create a `pandas.DataFrame` from the results of
the cross-validation done by the grid search. We flip the sign of the score
back so that the values are actual BIC scores, for which lower is better.

.. GENERATED FROM PYTHON SOURCE LINES 94-110

.. code-block:: default


    import pandas as pd

    df = pd.DataFrame(grid_search.cv_results_)[
        ["param_n_components", "param_covariance_type", "mean_test_score"]
    ]
    df["mean_test_score"] = -df["mean_test_score"]
    df = df.rename(
        columns={
            "param_n_components": "Number of components",
            "param_covariance_type": "Type of covariance",
            "mean_test_score": "BIC score",
        }
    )
    df.sort_values(by="BIC score").head()






.. raw:: html

    <div class="output_subarea output_html rendered_html output_result">
    <div>
    <style scoped>
        .dataframe tbody tr th:only-of-type {
            vertical-align: middle;
        }

        .dataframe tbody tr th {
            vertical-align: top;
        }

        .dataframe thead th {
            text-align: right;
        }
    </style>
    <table border="1" class="dataframe">
      <thead>
        <tr style="text-align: right;">
          <th></th>
          <th>Number of components</th>
          <th>Type of covariance</th>
          <th>BIC score</th>
        </tr>
      </thead>
      <tbody>
        <tr>
          <th>19</th>
          <td>2</td>
          <td>full</td>
          <td>1046.793662</td>
        </tr>
        <tr>
          <th>20</th>
          <td>3</td>
          <td>full</td>
          <td>1083.652535</td>
        </tr>
        <tr>
          <th>21</th>
          <td>4</td>
          <td>full</td>
          <td>1114.835102</td>
        </tr>
        <tr>
          <th>22</th>
          <td>5</td>
          <td>full</td>
          <td>1151.243322</td>
        </tr>
        <tr>
          <th>23</th>
          <td>6</td>
          <td>full</td>
          <td>1181.713463</td>
        </tr>
      </tbody>
    </table>
    </div>
    </div>
    <br />
    <br />

.. GENERATED FROM PYTHON SOURCE LINES 111-122

.. code-block:: default

    import seaborn as sns

    sns.catplot(
        data=df,
        kind="bar",
        x="Number of components",
        y="BIC score",
        hue="Type of covariance",
    )
    plt.show()




.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_002.png
   :alt: plot gmm selection
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 123-139

In the present case, the model with 2 components and full covariance (which
corresponds to the true generative model) has the lowest BIC score and is
therefore selected by the grid search.
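
Because the grid search cross-validates by default, the BIC values above are
computed on held-out folds and averaged. As a cross-check, the same candidates
can be compared with an explicit loop that fits each model on the full dataset
and evaluates the BIC on that same data. This is an alternative sketch, not
the method used in this example, and the fixed `random_state` is an assumption
added for reproducibility.

.. code-block:: python

    import numpy as np
    from sklearn.mixture import GaussianMixture

    # Same synthetic data as in the example.
    n_samples = 500
    np.random.seed(0)
    C = np.array([[0.0, -0.1], [1.7, 0.4]])
    component_1 = np.dot(np.random.randn(n_samples, 2), C)
    component_2 = 0.7 * np.random.randn(n_samples, 2) + np.array([-4, 1])
    X = np.concatenate([component_1, component_2])

    # Fit every candidate on all of X and record its BIC on the same data.
    bic_by_model = {}
    for covariance_type in ["spherical", "tied", "diag", "full"]:
        for n_components in range(1, 7):
            gmm = GaussianMixture(
                n_components=n_components,
                covariance_type=covariance_type,
                random_state=0,  # fixed seed, an assumption of this sketch
            ).fit(X)
            bic_by_model[(covariance_type, n_components)] = gmm.bic(X)

    # The preferred model is the one with the lowest BIC.
    best_model = min(bic_by_model, key=bic_by_model.get)
    print(best_model, bic_by_model[best_model])

The absolute BIC values differ from those in the table because they are
computed on all samples rather than on a single fold, so only the ranking of
the candidates is comparable.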

Plot the best model
-------------------

We plot an ellipse to show each Gaussian component of the selected model. For
this purpose, we compute the eigendecomposition of the covariance matrices
returned by the `covariances_` attribute. The shape of these matrices depends
on the `covariance_type`:

- `"full"`: (`n_components`, `n_features`, `n_features`)
- `"tied"`: (`n_features`, `n_features`)
- `"diag"`: (`n_components`, `n_features`)
- `"spherical"`: (`n_components`,)
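
The eigendecomposition needs one square matrix per component, which only the
`"full"` layout provides directly. A minimal sketch of how the other layouts
could be expanded to that shape is given below. The helper
`to_full_covariances` is hypothetical and not part of scikit-learn.

.. code-block:: python

    import numpy as np


    def to_full_covariances(covariances, covariance_type, n_components, n_features):
        """Expand covariances to shape (n_components, n_features, n_features).

        Hypothetical helper written for illustration only.
        """
        covariances = np.asarray(covariances, dtype=float)
        if covariance_type == "full":
            return covariances
        if covariance_type == "tied":
            # One shared matrix, repeated for every component.
            return np.tile(covariances, (n_components, 1, 1))
        if covariance_type == "diag":
            # Each row holds the diagonal of one component's matrix.
            return np.stack([np.diag(row) for row in covariances])
        if covariance_type == "spherical":
            # One variance per component, applied to every feature.
            return covariances[:, None, None] * np.eye(n_features)
        raise ValueError(f"Unknown covariance_type: {covariance_type!r}")


    spherical = to_full_covariances([0.49, 2.0], "spherical", 2, 2)
    diag = to_full_covariances([[1.0, 2.0], [3.0, 4.0]], "diag", 2, 2)
    tied = to_full_covariances([[2.89, 0.68], [0.68, 0.17]], "tied", 2, 2)
    print(spherical.shape, diag.shape, tied.shape)  # (2, 2, 2) for each

With such a helper, a plotting loop could iterate over one matrix per
component whichever covariance type the grid search selects. In this example
the selected model is `"full"`, so no conversion is needed.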

.. GENERATED FROM PYTHON SOURCE LINES 139-174

.. code-block:: default


    from matplotlib.patches import Ellipse
    from scipy import linalg

    color_iter = sns.color_palette("tab10", 2)[::-1]
    Y_ = grid_search.predict(X)

    fig, ax = plt.subplots()

    for i, (mean, cov, color) in enumerate(
        zip(
            grid_search.best_estimator_.means_,
            grid_search.best_estimator_.covariances_,
            color_iter,
        )
    ):
        v, w = linalg.eigh(cov)
        if not np.any(Y_ == i):
            continue
        plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], 0.8, color=color)

        # Eigenvectors are the columns of w; use the direction of the first one
        angle = np.arctan2(w[1][0], w[0][0])
        angle = 180.0 * angle / np.pi  # convert to degrees
        v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
        ellipse = Ellipse(mean, v[0], v[1], angle=180.0 + angle, color=color)
        ellipse.set_clip_box(fig.bbox)
        ellipse.set_alpha(0.5)
        ax.add_artist(ellipse)

    plt.title(
        f"Selected GMM: {grid_search.best_params_['covariance_type']} model, "
        f"{grid_search.best_params_['n_components']} components"
    )
    plt.axis("equal")
    plt.show()



.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_003.png
   :alt: Selected GMM: full model, 2 components
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_003.png
   :class: sphx-glr-single-img






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  1.164 seconds)


.. _sphx_glr_download_auto_examples_mixture_plot_gmm_selection.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_gmm_selection.py <plot_gmm_selection.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_gmm_selection.ipynb <plot_gmm_selection.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
