<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> <html xmlns="http://www.w3.org/1999/xhtml"> <head> <meta http-equiv="Content-Type" content="text/html; charset=utf-8" /> <title>Gaussian Processes classification example: exploiting the probabilistic output — scikits.learn v0.6.0 documentation</title> <link rel="stylesheet" href="../../_static/nature.css" type="text/css" /> <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" /> <script type="text/javascript"> var DOCUMENTATION_OPTIONS = { URL_ROOT: '../../', VERSION: '0.6.0', COLLAPSE_INDEX: false, FILE_SUFFIX: '.html', HAS_SOURCE: true }; </script> <script type="text/javascript" src="../../_static/jquery.js"></script> <script type="text/javascript" src="../../_static/underscore.js"></script> <script type="text/javascript" src="../../_static/doctools.js"></script> <link rel="shortcut icon" href="../../_static/favicon.ico"/> <link rel="author" title="About these documents" href="../../about.html" /> <link rel="top" title="scikits.learn v0.6.0 documentation" href="../../index.html" /> <link rel="up" title="Examples" href="../index.html" /> <link rel="next" title="Gaussian Processes regression: basic introductory example" href="plot_gp_regression.html" /> <link rel="prev" title="Gaussian Processes regression: goodness-of-fit on the ‘diabetes’ dataset" href="gp_diabetes_dataset.html" /> </head> <body> <div class="header-wrapper"> <div class="header"> <p class="logo"><a href="../../index.html"> <img src="../../_static/scikit-learn-logo-small.png" alt="Logo"/> </a> </p><div class="navbar"> <ul> <li><a href="../../install.html">Download</a></li> <li><a href="../../support.html">Support</a></li> <li><a href="../../user_guide.html">User Guide</a></li> <li><a href="../index.html">Examples</a></li> <li><a href="../../developers/index.html">Development</a></li> </ul> <div class="search_form"> <div id="cse" style="width: 100%;"></div> 
<script src="http://www.google.com/jsapi" type="text/javascript"></script> <script type="text/javascript"> google.load('search', '1', {language : 'en'}); google.setOnLoadCallback(function() { var customSearchControl = new google.search.CustomSearchControl('016639176250731907682:tjtqbvtvij0'); customSearchControl.setResultSetSize(google.search.Search.FILTERED_CSE_RESULTSET); var options = new google.search.DrawOptions(); options.setAutoComplete(true); customSearchControl.draw('cse', options); }, true); </script> </div> </div> <!-- end navbar --></div> </div> <div class="content-wrapper"> <!-- <div id="blue_tile"></div> --> <div class="sphinxsidebar"> <div class="rel"> <a href="gp_diabetes_dataset.html" title="Gaussian Processes regression: goodness-of-fit on the ‘diabetes’ dataset" accesskey="P">previous</a> | <a href="plot_gp_regression.html" title="Gaussian Processes regression: basic introductory example" accesskey="N">next</a> | <a href="../../genindex.html" title="General Index" accesskey="I">index</a> </div> <h3>Contents</h3> <ul> <li><a class="reference internal" href="#">Gaussian Processes classification example: exploiting the probabilistic output</a></li> </ul> </div> <div class="content"> <div class="documentwrapper"> <div class="bodywrapper"> <div class="body"> <div class="section" id="gaussian-processes-classification-example-exploiting-the-probabilistic-output"> <span id="example-gaussian-process-plot-gp-probabilistic-classification-after-regression-py"></span><h1>Gaussian Processes classification example: exploiting the probabilistic output<a class="headerlink" href="#gaussian-processes-classification-example-exploiting-the-probabilistic-output" title="Permalink to this headline">¶</a></h1> <p>A two-dimensional regression exercise with a post-processing step that enables probabilistic classification, thanks to the Gaussian property of the prediction.</p> <p>The figure shows the probability that the prediction is negative, given the remaining 
uncertainty in the prediction. The red and blue lines correspond to the 95% confidence interval on the prediction of the zero level set.</p> <img alt="auto_examples/gaussian_process/images/plot_gp_probabilistic_classification_after_regression.png" class="align-center" src="auto_examples/gaussian_process/images/plot_gp_probabilistic_classification_after_regression.png" /> <p><strong>Python source code:</strong> <a class="reference download internal" href="../../_downloads/plot_gp_probabilistic_classification_after_regression.py"><tt class="xref download docutils literal"><span class="pre">plot_gp_probabilistic_classification_after_regression.py</span></tt></a></p> <div class="highlight-python"><div class="highlight"><pre><span class="k">print</span> <span class="n">__doc__</span>

<span class="c"># Author: Vincent Dubourg &lt;vincent.dubourg@gmail.com&gt;</span>
<span class="c"># License: BSD style</span>

<span class="kn">import</span> <span class="nn">numpy</span> <span class="kn">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">scipy</span> <span class="kn">import</span> <span class="n">stats</span>
<span class="kn">from</span> <span class="nn">scikits.learn.gaussian_process</span> <span class="kn">import</span> <span class="n">GaussianProcess</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">pl</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">cm</span>

<span class="c"># Standard normal distribution functions</span>
<span class="n">phi</span> <span class="o">=</span> <span class="n">stats</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">norm</span><span class="p">()</span><span class="o">.</span><span class="n">pdf</span>
<span class="n">PHI</span> <span class="o">=</span> <span 
class="n">stats</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">norm</span><span class="p">()</span><span class="o">.</span><span class="n">cdf</span>
<span class="n">PHIinv</span> <span class="o">=</span> <span class="n">stats</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">norm</span><span class="p">()</span><span class="o">.</span><span class="n">ppf</span>

<span class="c"># A few constants</span>
<span class="n">lim</span> <span class="o">=</span> <span class="mi">8</span>


<span class="k">def</span> <span class="nf">g</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="sd">"""The function to predict (classification will then consist in predicting</span>
<span class="sd">    whether g(x) &lt;= 0 or not)"""</span>
    <span class="k">return</span> <span class="mf">5.</span> <span class="o">-</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="o">.</span><span class="mi">5</span> <span class="o">*</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">**</span> <span class="mf">2.</span>

<span class="c"># Design of experiments</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.61611719</span><span class="p">,</span> <span class="o">-</span><span class="mf">6.00099547</span><span class="p">],</span>
              <span class="p">[</span><span class="mf">4.10469096</span><span class="p">,</span> <span class="mf">5.32782448</span><span class="p">],</span>
              <span class="p">[</span><span class="mf">0.00000000</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.50000000</span><span class="p">],</span>
              <span 
class="p">[</span><span class="o">-</span><span class="mf">6.17289014</span><span class="p">,</span> <span class="o">-</span><span class="mf">4.6984743</span><span class="p">],</span>
              <span class="p">[</span><span class="mf">1.3109306</span><span class="p">,</span> <span class="o">-</span><span class="mf">6.93271427</span><span class="p">],</span>
              <span class="p">[</span><span class="o">-</span><span class="mf">5.03823144</span><span class="p">,</span> <span class="mf">3.10584743</span><span class="p">],</span>
              <span class="p">[</span><span class="o">-</span><span class="mf">2.87600388</span><span class="p">,</span> <span class="mf">6.74310541</span><span class="p">],</span>
              <span class="p">[</span><span class="mf">5.21301203</span><span class="p">,</span> <span class="mf">4.26386883</span><span class="p">]])</span>

<span class="c"># Observations</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">g</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>

<span class="c"># Instantiate and fit Gaussian Process Model</span>
<span class="n">gp</span> <span class="o">=</span> <span class="n">GaussianProcess</span><span class="p">(</span><span class="n">theta0</span><span class="o">=</span><span class="mf">5e-1</span><span class="p">)</span>

<span class="c"># Don't perform MLE or you'll get a perfect prediction for this simple example!</span>
<span class="n">gp</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>

<span class="c"># Evaluate real function, the prediction and its MSE on a grid</span>
<span class="n">res</span> <span class="o">=</span> <span class="mi">50</span>
<span class="n">x1</span><span class="p">,</span> <span class="n">x2</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">(</span><span 
class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span> <span class="n">lim</span><span class="p">,</span> <span class="n">lim</span><span class="p">,</span> <span class="n">res</span><span class="p">),</span> \
                     <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span> <span class="n">lim</span><span class="p">,</span> <span class="n">lim</span><span class="p">,</span> <span class="n">res</span><span class="p">))</span>
<span class="n">xx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">([</span><span class="n">x1</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x1</span><span class="o">.</span><span class="n">size</span><span class="p">),</span> <span class="n">x2</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x2</span><span class="o">.</span><span class="n">size</span><span class="p">)])</span><span class="o">.</span><span class="n">T</span>

<span class="n">y_true</span> <span class="o">=</span> <span class="n">g</span><span class="p">(</span><span class="n">xx</span><span class="p">)</span>
<span class="n">y_pred</span><span class="p">,</span> <span class="n">MSE</span> <span class="o">=</span> <span class="n">gp</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">xx</span><span class="p">,</span> <span class="n">eval_MSE</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">sigma</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">MSE</span><span class="p">)</span>
<span class="n">y_true</span> <span class="o">=</span> <span class="n">y_true</span><span 
class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">res</span><span class="p">,</span> <span class="n">res</span><span class="p">))</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">res</span><span class="p">,</span> <span class="n">res</span><span class="p">))</span>
<span class="n">sigma</span> <span class="o">=</span> <span class="n">sigma</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">res</span><span class="p">,</span> <span class="n">res</span><span class="p">))</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">PHIinv</span><span class="p">(</span><span class="o">.</span><span class="mi">975</span><span class="p">)</span>

<span class="c"># Plot the probabilistic classification iso-values using the Gaussian property</span>
<span class="c"># of the prediction</span>
<span class="n">fig</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">111</span><span class="p">)</span>
<span class="n">ax</span><span class="o">.</span><span class="n">axes</span><span class="o">.</span><span class="n">set_aspect</span><span class="p">(</span><span class="s">'equal'</span><span class="p">)</span>
<span class="n">pl</span><span class="o">.</span><span class="n">xticks</span><span class="p">([])</span>
<span class="n">pl</span><span class="o">.</span><span class="n">yticks</span><span class="p">([])</span>
<span class="n">ax</span><span class="o">.</span><span class="n">set_xticklabels</span><span class="p">([])</span>
<span 
class="n">ax</span><span class="o">.</span><span class="n">set_yticklabels</span><span class="p">([])</span>
<span class="n">pl</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'$x_1$'</span><span class="p">)</span>
<span class="n">pl</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'$x_2$'</span><span class="p">)</span>

<span class="n">cax</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">flipud</span><span class="p">(</span><span class="n">PHI</span><span class="p">(</span><span class="o">-</span> <span class="n">y_pred</span> <span class="o">/</span> <span class="n">sigma</span><span class="p">)),</span> <span class="n">cmap</span><span class="o">=</span><span class="n">cm</span><span class="o">.</span><span class="n">gray_r</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> \
                <span class="n">extent</span><span class="o">=</span><span class="p">(</span><span class="o">-</span> <span class="n">lim</span><span class="p">,</span> <span class="n">lim</span><span class="p">,</span> <span class="o">-</span> <span class="n">lim</span><span class="p">,</span> <span class="n">lim</span><span class="p">))</span>
<span class="n">norm</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span class="n">matplotlib</span><span class="o">.</span><span class="n">colors</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="n">vmin</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mf">0.9</span><span class="p">)</span>
<span class="n">cb</span> <span class="o">=</span> <span 
class="n">pl</span><span class="o">.</span><span class="n">colorbar</span><span class="p">(</span><span class="n">cax</span><span class="p">,</span> <span class="n">ticks</span><span class="o">=</span><span class="p">[</span><span class="mf">0.</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="n">norm</span><span class="o">=</span><span class="n">norm</span><span class="p">)</span>
<span class="n">cb</span><span class="o">.</span><span class="n">set_label</span><span class="p">(</span><span class="s">'${</span><span class="se">\\</span><span class="s">rm \mathbb{P}}\left[\widehat{G}(\mathbf{x}) \leq 0</span><span class="se">\\</span><span class="s">right]$'</span><span class="p">)</span>

<span class="n">pl</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">y</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">X</span><span class="p">[</span><span class="n">y</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="s">'r.'</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">12</span><span class="p">)</span>

<span class="n">pl</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">y</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">X</span><span class="p">[</span><span class="n">y</span> <span 
class="o">&gt;</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="s">'b.'</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">12</span><span class="p">)</span>

<span class="n">cs</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span class="n">contour</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="p">[</span><span class="mf">0.</span><span class="p">],</span> <span class="n">colors</span><span class="o">=</span><span class="s">'k'</span><span class="p">,</span> \
                <span class="n">linestyles</span><span class="o">=</span><span class="s">'dashdot'</span><span class="p">)</span>

<span class="n">cs</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span class="n">contour</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">PHI</span><span class="p">(</span><span class="o">-</span> <span class="n">y_pred</span> <span class="o">/</span> <span class="n">sigma</span><span class="p">),</span> <span class="p">[</span><span class="mf">0.025</span><span class="p">],</span> <span class="n">colors</span><span class="o">=</span><span class="s">'b'</span><span class="p">,</span> \
                <span class="n">linestyles</span><span class="o">=</span><span class="s">'solid'</span><span class="p">)</span>
<span class="n">pl</span><span class="o">.</span><span class="n">clabel</span><span class="p">(</span><span class="n">cs</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>

<span class="n">cs</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span 
class="n">contour</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">PHI</span><span class="p">(</span><span class="o">-</span> <span class="n">y_pred</span> <span class="o">/</span> <span class="n">sigma</span><span class="p">),</span> <span class="p">[</span><span class="mf">0.5</span><span class="p">],</span> <span class="n">colors</span><span class="o">=</span><span class="s">'k'</span><span class="p">,</span> \
                <span class="n">linestyles</span><span class="o">=</span><span class="s">'dashed'</span><span class="p">)</span>
<span class="n">pl</span><span class="o">.</span><span class="n">clabel</span><span class="p">(</span><span class="n">cs</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>

<span class="n">cs</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span class="n">contour</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">PHI</span><span class="p">(</span><span class="o">-</span> <span class="n">y_pred</span> <span class="o">/</span> <span class="n">sigma</span><span class="p">),</span> <span class="p">[</span><span class="mf">0.975</span><span class="p">],</span> <span class="n">colors</span><span class="o">=</span><span class="s">'r'</span><span class="p">,</span> \
                <span class="n">linestyles</span><span class="o">=</span><span class="s">'solid'</span><span class="p">)</span>
<span class="n">pl</span><span class="o">.</span><span class="n">clabel</span><span class="p">(</span><span class="n">cs</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>

<span class="n">pl</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</pre></div>
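<p>The post-processing step used by this example can also be looked at in isolation. If the prediction at a point is Gaussian with mean <tt class="docutils literal"><span class="pre">mu</span></tt> and standard deviation <tt class="docutils literal"><span class="pre">sigma</span></tt>, the probability that it is at most zero is the standard normal CDF evaluated at <tt class="docutils literal"><span class="pre">-mu / sigma</span></tt>, which is the quantity mapped and contoured in the figure. The following is only a minimal sketch under that assumption and is not part of the original script: the helper name <tt class="docutils literal"><span class="pre">prob_nonpositive</span></tt> and the sample values of <tt class="docutils literal"><span class="pre">mu</span></tt> and <tt class="docutils literal"><span class="pre">sigma</span></tt> are hypothetical, and it relies on <tt class="docutils literal"><span class="pre">numpy</span></tt> and <tt class="docutils literal"><span class="pre">scipy.stats.norm</span></tt> alone instead of a fitted Gaussian process.</p>
<div class="highlight"><pre>import numpy as np
from scipy.stats import norm


def prob_nonpositive(mean, std):
    """Probability that a Gaussian variable N(mean, std**2) is at most zero.

    Standardizing gives Phi(-mean / std), with Phi the standard normal CDF.
    Assumes std is strictly positive.
    """
    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)
    return norm.cdf(-mean / std)


# Hypothetical predictive means and standard deviations at three points
mu = np.array([2.0, 0.5, -2.0])
sigma = np.array([1.0, 1.0, 1.0])

p = prob_nonpositive(mu, sigma)

# Points whose probability lies between 0.025 and 0.975, i.e. where the
# sign of the prediction is not settled at the 95% level; this is the
# same as abs(mu) being at most k * sigma with k = ppf(0.975)
k = norm.ppf(0.975)
undecided = np.less_equal(np.abs(mu), k * sigma)
</pre></div>
<p>With these made-up values only the middle point falls inside the band bounded by the 0.025 and 0.975 probability levels. The sketch assumes a strictly positive standard deviation; an interpolating model can report a zero MSE at its training points, which would need separate handling.</p>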
</div> </div> </div> </div> </div> <div class="clearer"></div> </div> </div> <div class="footer"> <p style="text-align: center">This documentation is for scikits.learn version 0.6.0</p> <p>© 2010, scikits.learn developers (BSD License). Created using <a href="http://sphinx.pocoo.org/">Sphinx</a> 1.0.5. Design by <a href="http://webylimonada.com">Web y Limonada</a>.</p> </div> </body> </html>