<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://vitaliset.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://vitaliset.github.io/" rel="alternate" type="text/html" /><updated>2026-06-07T23:51:19+00:00</updated><id>https://vitaliset.github.io/feed.xml</id><title type="html">Vitali Set</title><subtitle>Personal blog with technical discussions, experiments, and reproducible deep-dives in data science and machine learning, by Carlo Lemos.</subtitle><author><name>Carlo Lemos</name></author><entry><title type="html">Evaluating ranking in regression</title><link href="https://vitaliset.github.io/evaluating-ranking-in-regression/" rel="alternate" type="text/html" title="Evaluating ranking in regression" /><published>2024-11-17T00:00:00+00:00</published><updated>2024-11-17T00:00:00+00:00</updated><id>https://vitaliset.github.io/evaluating-ranking-in-regression</id><content type="html" xml:base="https://vitaliset.github.io/evaluating-ranking-in-regression/"><![CDATA[<p><div align="justify">In supervised learning regression problems, the focus is generally on metrics that ensure the predicted value is close to the true value of the sample. Classic regression metrics are variations that involve the measure $| \hat{y_i} - y_i |$.</div></p>

<p><div align="justify">However, predicting the exact value of the target variable is not always essential, since in some applications exactness is not critical to the final objective. In many cases, a good ranking of the predictions is enough to meet the demands of the business problem. Of course, this depends on the context, but with a proper ranking we can approach the problem much like choosing a <a href="https://vitaliset.github.io/threshold-dependent-opt/">threshold in binary classification</a> or, more generally, as a policy problem. Here, the most appropriate cutoff point is identified through additional analysis to implement a desired treatment or action, such as targeting individuals with an expected credit card expense greater than $\delta$ for a new product marketing campaign. In most companies, the policy is structured around buckets of relevant percentiles, which are inherently based on ranking.</div></p>

<p><div align="justify">In other scenarios, such as income estimation, the output of the regression model is often used as an input feature for subsequent models. These models, frequently ensembles of decision trees, split on thresholds and are therefore insensitive to the exact values of a feature, depending only on its ordering. If the final model is, for instance, a logistic regression or even a neural network, simple transformations are typically applied that alter the distribution of the values while preserving monotonicity. Again, the exact values matter much less than the ranking.</div></p>

<p><div align="justify">From this perspective, it becomes clear that regression problems may require specific metrics to evaluate the quality of the ranking rather than relying solely on metrics that aim to minimize variations of $| \hat{y_i} - y_i |$.</div></p>

<p><div align="justify">$\oint$ <em>It is worth emphasizing that ranking-oriented metrics are particularly relevant in domains such as recommendation systems, where the primary objective is to provide an optimal ranking of items rather than precise value predictions. I believe that adapting recommendation system metrics could also be highly effective in addressing challenges in other domains. However, these adaptations might not be as straightforward as those discussed in this post.</em></div></p>

<hr />

<p><div align="justify">To illustrate the metrics, let’s assume we built three different models that produce scores for the same prediction problem, with the test set targets given by <code>y_true</code>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="n">random_state_0</span><span class="p">,</span> <span class="n">random_state_1</span><span class="p">,</span> <span class="n">random_state_2</span><span class="p">,</span> <span class="n">random_state_3</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="mi">42</span><span class="p">).</span><span class="n">randint</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="mi">2</span><span class="o">**</span><span class="mi">32</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>

<span class="n">y_true</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">random_state_0</span><span class="p">).</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">1_000</span><span class="p">)</span>

<span class="n">y_score_1</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="mi">3</span> <span class="o">+</span> <span class="n">y_true</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">random_state_1</span><span class="p">).</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">y_true</span><span class="p">))</span>
<span class="n">y_score_2</span> <span class="o">=</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">y_true</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">random_state_2</span><span class="p">).</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">y_true</span><span class="p">))</span>
<span class="n">y_score_3</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">random_state_3</span><span class="p">).</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">1_000</span><span class="p">)</span>

<span class="n">SCORES</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">zip</span><span class="p">([</span><span class="s">'y_score_1'</span><span class="p">,</span> <span class="s">'y_score_2'</span><span class="p">,</span> <span class="s">'y_score_3'</span><span class="p">],</span> <span class="p">[</span><span class="n">y_score_1</span><span class="p">,</span> <span class="n">y_score_2</span><span class="p">,</span> <span class="n">y_score_3</span><span class="p">]))</span>
</code></pre></div></div>

<p><div align="justify">Without delving into the specifics of how these scores were generated, the most natural and well-known way to evaluate these models would be using metrics such as $R^2$, $\textrm{RMSE}$, or some variation of these. These metrics are very useful but do not necessarily provide much insight into ranking.</div></p>

<p><div align="justify">In our example, judging by the $\textrm{RMSE}$ alone, <code>y_score_3</code> seems to be the best predictor.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn</span> <span class="kn">import</span> <span class="n">metrics</span>

<span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">y_score</span> <span class="ow">in</span> <span class="n">SCORES</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">rmse</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">metrics</span><span class="p">.</span><span class="n">mean_squared_error</span><span class="p">(</span><span class="n">y_true</span><span class="o">=</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_score</span><span class="p">))</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"RMSE for </span><span class="si">{</span><span class="n">score_name</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">rmse</span><span class="si">:</span><span class="mf">6.3</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>RMSE for y_score_1: 58.205
RMSE for y_score_2:  2.242
RMSE for y_score_3:  1.400
</code></pre></div></div>

<hr />

<h2 id="spearmans-correlation">Spearman’s Correlation</h2>

<p><div align="justify">When we want to evaluate how a regression model ranks the data, it is natural to consider measures of correlation between two continuous variables, and an initial idea might be to use Pearson's correlation. However, Pearson's correlation measures only linear association and does not directly account for the relative order of the values. Thus, even if the model orders the samples perfectly, Pearson's correlation may understate the quality of the ranking when the relationship between predictions and true values is far from linear.</div></p>
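<p><div align="justify">This limitation is easy to see with a minimal sketch on a synthetic relationship that is perfectly monotonic but strongly nonlinear. The data is illustrative and unrelated to the three models above. Spearman's correlation, discussed next, is exactly 1 because the order is fully preserved, while Pearson's correlation stays well below 1.</div></p>

```python
import numpy as np
from scipy import stats

# A perfectly monotonic but strongly nonlinear relationship
x = np.arange(1, 101, dtype=float)
y = np.exp(x / 10)

pearson = stats.pearsonr(x, y)[0]
spearman = stats.spearmanr(x, y)[0]

print(f"Pearson:  {pearson:.3f}")   # noticeably below 1
print(f"Spearman: {spearman:.3f}")  # 1.000: the order is fully preserved
```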

<p><div align="justify">This is where Spearman's correlation ($\rho$) becomes an interesting metric, as it measures the similarity between the rankings of these variables [<a href="#bibliography">1</a>]. In other words, it evaluates whether the order of the values is consistent between the two. This makes Spearman's correlation particularly useful in problems where the relative position of the values is more important than their magnitudes.</div></p>

<p><div align="justify">$\oint$ <em>Spearman's correlation can be seen as a version of Pearson's correlation applied to the ranks of the variables instead of their original values. Under the hood, Spearman transforms the data by replacing each value with its position in the ranking and then calculates Pearson's correlation on these ranks.</em></div></p>
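<p><div align="justify">The rank-then-Pearson view can be checked directly. The sketch below uses illustrative random data. It ranks both variables with <code>scipy.stats.rankdata</code>, which assigns average ranks to ties, applies <code>scipy.stats.pearsonr</code> to those ranks, and compares the result with <code>scipy.stats.spearmanr</code>.</div></p>

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
w = rng.normal(size=500)
z = np.exp(w) + rng.normal(scale=0.5, size=500)

# Spearman's rho "by hand": replace values by their ranks, then apply Pearson
rho_by_ranks = stats.pearsonr(stats.rankdata(w), stats.rankdata(z))[0]
rho_scipy = stats.spearmanr(w, z)[0]

print(f"Pearson on ranks:  {rho_by_ranks:.6f}")
print(f"scipy's spearmanr: {rho_scipy:.6f}")
```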

<p><div align="justify">If there are no ties in the ranks, the simplified formula for Spearman's correlation is given by</div></p>

\[\rho = 1 - \frac{6 \sum_{i=1}^n d_i^2}{n(n^2 - 1)},\]

<p><div align="justify">where $n$ is the total number of observations, and $d_i$ is the difference between the ranks of the same observation in the two variables. To calculate $d_i$, we first assign a rank to each value of the variables. For example, given a set of values $\{w_i\}_{i=1}^n$ and $\{z_i\}_{i=1}^n$, we sort each set separately and replace the values with their respective ranks. Then, for each observation $i$, we compute</div></p>

\[d_i = \text{rank}(w_i) - \text{rank}(z_i).\]
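<p><div align="justify">Here is a small worked example of the closed-form formula. It uses a hypothetical helper, <code>spearman_no_ties</code>, and two made-up vectors without ties. The ranks are $(4, 1, 5, 2, 3)$ and $(3, 1, 5, 4, 2)$, so $\sum d_i^2 = 1 + 0 + 0 + 4 + 1 = 6$ and $\rho = 1 - 36/120 = 0.7$. This matches <code>scipy.stats.spearmanr</code>.</div></p>

```python
import numpy as np
from scipy import stats


def spearman_no_ties(w, z):
    """Spearman's rho via the closed-form formula (valid only without ties)."""
    n = len(w)
    # rankdata assigns ranks 1..n; without ties they form a permutation
    d = stats.rankdata(w) - stats.rankdata(z)
    return 1 - 6 * np.sum(d ** 2) / (n * (n ** 2 - 1))


w = [3.1, 0.2, 5.7, 1.4, 2.2]
z = [2.0, 0.1, 9.0, 3.5, 1.0]
print(spearman_no_ties(w, z))    # 0.7
print(stats.spearmanr(w, z)[0])  # same value, computed as Pearson on ranks
```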

<p><div align="justify">The values of $\rho$ range between -1 and 1. A value of 1 indicates perfect ranking agreement, -1 indicates a complete inversion of the ranking, and values near 0 indicate no monotonic relationship between the variables.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">scipy</span> <span class="kn">import</span> <span class="n">stats</span>

<span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">y_score</span> <span class="ow">in</span> <span class="n">SCORES</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">spearman</span> <span class="o">=</span> <span class="n">stats</span><span class="p">.</span><span class="n">spearmanr</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">).</span><span class="n">statistic</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Spearman for </span><span class="si">{</span><span class="n">score_name</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">spearman</span><span class="si">:</span><span class="mf">7.5</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Spearman for y_score_1: 0.99759
Spearman for y_score_2: 0.94718
Spearman for y_score_3: 0.01447
</code></pre></div></div>

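<p><div align="justify">With this metric, <code>y_score_1</code> and <code>y_score_2</code> clearly stand out for their ability to sort <code>y_true</code>. By contrast, <code>y_score_3</code>, the best model according to RMSE, shows essentially no rank correlation with the target.</div></p>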

<hr />

<h2 id="kendalls-tau-correlation">Kendall’s Tau Correlation</h2>

<p><div align="justify">Another common metric for evaluating rankings is Kendall's rank correlation coefficient, Kendall's Tau ($\tau$). It measures the strength of association between two rankings by comparing pairs of observations and determining whether they are concordant or discordant [<a href="#bibliography">2</a>].</div></p>

<p><div align="justify">Two observations $(w_i, z_i)$ and $(w_j, z_j)$ are said to be:</div></p>

<ul>
  <li>
    <p><div align="justify">concordant: if the ranking of $w_i$ relative to $w_j$ is the same as the ranking of $z_i$ relative to $z_j$. Formally, this occurs when</div></p>
  </li>
</ul>

\[(w_i - w_j)(z_i - z_j) &gt; 0.\]

<ul>
  <li>
    <p><div align="justify">discordant: if the ranking of $w_i$ relative to $w_j$ is the opposite of that of $z_i$ relative to $z_j$. In other words,</div></p>
  </li>
</ul>

\[(w_i - w_j)(z_i - z_j) &lt; 0.\]

<p><div align="justify">The formula for Kendall's Tau is</div></p>

\[\tau = \frac{C - D}{\frac{1}{2} n(n-1)},\]

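<p><div align="justify">where $C$ is the number of concordant pairs and $D$ is the number of discordant pairs. The denominator, $\frac{1}{2} n(n-1)$, is the total number of possible pairs among $n$ observations. This version, known as tau-a, does not correct for ties. By default, <code>scipy.stats.kendalltau</code> computes tau-b, which adjusts the denominator for ties. When there are no ties, the two versions coincide.</div></p>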
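<p><div align="justify">A direct $\mathcal{O}(n^2)$ translation of the formula might look like the sketch below. <code>kendall_tau_a</code> is a hypothetical helper meant to illustrate the formula, not to be fast. On continuous data without ties, it should agree with <code>scipy.stats.kendalltau</code>.</div></p>

```python
import itertools

import numpy as np
from scipy import stats


def kendall_tau_a(w, z):
    """Naive O(n^2) Kendall's tau-a: (C - D) / (n (n - 1) / 2)."""
    concordant = discordant = 0
    for i, j in itertools.combinations(range(len(w)), 2):
        sign = (w[i] - w[j]) * (z[i] - z[j])
        if sign > 0:
            concordant += 1
        elif sign < 0:
            discordant += 1
        # sign == 0 (a tie) counts as neither concordant nor discordant
    n = len(w)
    return (concordant - discordant) / (0.5 * n * (n - 1))


rng = np.random.default_rng(0)
w = rng.normal(size=200)
z = w + rng.normal(size=200)

print(kendall_tau_a(w, z))
print(stats.kendalltau(w, z)[0])  # should match: continuous data has no ties
```

<p><div align="justify">For a hand check, take $w = (1, 2, 3)$ and $z = (1, 3, 2)$. There are two concordant pairs and one discordant pair, so $\tau = 1/3$.</div></p>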

<p><div align="justify">Similar to Spearman's correlation, Kendall's Tau ranges between -1 and 1, with the same interpretation: when $\tau$ approaches 1, the rankings are highly concordant; when it approaches -1, the rankings are reversed; and when $\tau \approx 0$, there is no association between the rankings.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">y_score</span> <span class="ow">in</span> <span class="n">SCORES</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">kendall</span> <span class="o">=</span> <span class="n">stats</span><span class="p">.</span><span class="n">kendalltau</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">).</span><span class="n">statistic</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"kendall's tau for </span><span class="si">{</span><span class="n">score_name</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">kendall</span><span class="si">:</span><span class="mf">7.5</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>kendall's tau for y_score_1: 0.96163
kendall's tau for y_score_2: 0.80227
kendall's tau for y_score_3: 0.00976
</code></pre></div></div>

<p><div align="justify">$\oint$ <em>It is possible to adapt the metric to account for sample weights by assigning the weight of a pair as the product of the weights of the samples.</em></div></p>
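<p><div align="justify">A minimal sketch of this weighting, written as the hypothetical helper <code>weighted_kendall_tau_a</code>, gives each pair $(i, j)$ the weight $w_i w_j$. The numerator is the weighted number of concordant pairs minus the weighted number of discordant pairs, and the denominator is the total pair weight. With unit weights it reduces to tau-a. Note that <code>scipy.stats.weightedtau</code> uses a different, rank-based weighting scheme, so it is not a drop-in replacement for this idea.</div></p>

```python
import itertools


def weighted_kendall_tau_a(w, z, sample_weight):
    """Kendall's tau-a with each pair (i, j) weighted by
    sample_weight[i] * sample_weight[j].

    Hypothetical helper for illustration, not a SciPy API.
    """
    weighted_balance = 0.0  # weighted concordant minus weighted discordant
    total_weight = 0.0      # weighted number of pairs (the denominator)
    for i, j in itertools.combinations(range(len(w)), 2):
        pair_weight = sample_weight[i] * sample_weight[j]
        sign = (w[i] - w[j]) * (z[i] - z[j])
        if sign > 0:
            weighted_balance += pair_weight
        elif sign < 0:
            weighted_balance -= pair_weight
        total_weight += pair_weight
    return weighted_balance / total_weight


print(weighted_kendall_tau_a([1, 2, 3], [1, 3, 2], [1, 1, 1]))  # 1/3, as unweighted
print(weighted_kendall_tau_a([1, 2, 3], [1, 3, 2], [2, 1, 1]))  # (2 + 2 - 1) / 5 = 0.6
```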

<hr />

<h2 id="rocauc-for-classification">ROCAUC for Classification</h2>

<p><div align="justify">In binary classification, the ROCAUC is a very good metric for measuring ranking quality [<a href="#bibliography">3</a>]. It is a <a href="https://pibieta.github.io/imbalanced_learning/notebooks/Metrics%201%20-%20Intro%20%26%20ROC%20AUC.html#proof-of-probabilistic-interpretation-of-the-roc-auc">class imbalance-invariant metric</a> with a direct probabilistic interpretation in terms of ranking. In my experience, it is the primary metric used in industry for binary classification problems when ranking is the main goal.</div></p>

<p><div align="justify">It is possible to <a href="https://pibieta.github.io/imbalanced_learning/notebooks/Metrics%201%20-%20Intro%20%26%20ROC%20AUC.html#proof-of-probabilistic-interpretation-of-the-roc-auc">prove</a> the following. Consider a binary classification problem with explanatory variables $X \in \mathcal{X}$ and target $Y \in \{0, 1\}$, a scoring/ranking function $f:\mathcal{X} \to \mathbb{R}$, and two independent copies $(X_i, Y_i)$ and $(X_j, Y_j)$ of $(X, Y)$. Then</div></p>

\[\text{ROCAUC}(f) = \mathbb{P}\left( f(X_i) &gt; f(X_j) \mid Y_i = 1, Y_j = 0 \right).\]

<p><div align="justify">In other words, if we select a random sample from class 1 and a random sample from class 0 in our binary classification problem, the ROCAUC coincides with the probability that the score given to the class 1 sample is greater than the score given to the class 0 sample.</div></p>

<p><div align="justify">Because of this probabilistic interpretation of the metric, a good ROCAUC for your classifier is equivalent to a good ranking when using your classifier as a means of ordering.</div></p>

<hr />

<h2 id="estimating-the-rocauc-via-the-wilcoxon-mann-whitney-statistic">Estimating the ROCAUC via the Wilcoxon-Mann-Whitney statistic</h2>

<p><div align="justify">The previous definition refers to the true ROCAUC value, rather than the estimated value we calculate using <a href="https://scikit-learn.org/1.5/modules/generated/sklearn.metrics.roc_auc_score.html"><code>sklearn.metrics.roc_auc_score</code></a>. In an observed random sample of $(X, Y)$, $\{(x_i, y_i)\}_{i=1}^n$, the probabilistic version can be estimated using the Wilcoxon-Mann-Whitney statistic as</div></p>

\[\frac{1}{n_0 n_1} \sum_{i : y_i = 1} \sum_{j : y_j = 0} \mathbb{1}\left( f(x_i) &gt; f(x_j) \right),\]

<p><div align="justify">where $n_0$ and $n_1$ are the numbers of elements in classes $0$ and $1$, respectively, and $\mathbb{1}\left(S\right)$ is the indicator function, equal to 1 when the condition $S$ is true and 0 otherwise. When scores can be tied, <code>roc_auc_score</code> counts each tied pair as $1/2$ instead of 0.</div></p>
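<p><div align="justify">As a sanity check of this estimator, the sketch below uses illustrative synthetic data and computes the double sum directly by broadcasting. It compares the result with <code>sklearn.metrics.roc_auc_score</code>. It also compares it with the Mann-Whitney $U$ statistic from <code>scipy.stats.mannwhitneyu</code> divided by $n_0 n_1$, assuming SciPy 1.7 or later, where the returned statistic refers to the first sample. Since the scores are continuous, there are no ties and the three values should coincide.</div></p>

```python
import numpy as np
from scipy import stats
from sklearn import metrics

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=300)
score = y + rng.normal(size=300)

pos, neg = score[y == 1], score[y == 0]

# Direct Wilcoxon-Mann-Whitney estimate: the fraction of the n1 * n0
# (positive, negative) pairs in which the positive has the higher score
wmw = np.mean(pos[:, None] > neg[None, :])

# Same quantity from the Mann-Whitney U statistic of the positive sample
# (assumes SciPy 1.7+, where `statistic` refers to the first sample)
u_auc = stats.mannwhitneyu(pos, neg).statistic / (len(pos) * len(neg))

print(wmw, u_auc, metrics.roc_auc_score(y, score))
```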

<p><div align="justify">There are some variations of this statistic that are cheaper to compute. In its current form, it requires a number of comparisons on the order of $\mathcal{O}(n_0 n_1)$, or $\mathcal{O}(n^2)$ if $n_1 \approx n_0$, which can be impractical [<a href="#bibliography">3</a>]. The simplest variation evaluates only a sufficiently large number $N$ of pairs $(i_k, j_k)$, $k = 1, \dots, N$, sampled at random with $y_{i_k} = 1$ and $y_{j_k} = 0$. This gives the version</div></p>

\[\widehat{\text{ROCAUC}}(f) = \frac{1}{N} \sum_{k=1}^{N} \mathbb{1}\left( f(x_{i_k}) &gt; f(x_{j_k}) \right).\]
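<p><div align="justify">A minimal sketch of the sampled version, written as the hypothetical helper <code>sampled_roc_auc</code>, draws the $N$ pairs with replacement. It takes one index from class 1 and one from class 0 for each pair, then averages the indicator. With $N = 100{,}000$ pairs, the standard error is below $0.002$, so the estimate should be close to the exhaustive <code>roc_auc_score</code>.</div></p>

```python
import numpy as np
from sklearn import metrics


def sampled_roc_auc(y, score, n_pairs=100_000, random_state=0):
    """Estimate ROCAUC from N random (positive, negative) pairs (a sketch)."""
    rng = np.random.default_rng(random_state)
    y, score = np.asarray(y), np.asarray(score)
    pos, neg = score[y == 1], score[y == 0]
    # Draw pairs (i_k, j_k) with replacement: i_k from class 1, j_k from class 0
    i = rng.integers(0, len(pos), size=n_pairs)
    j = rng.integers(0, len(neg), size=n_pairs)
    return np.mean(pos[i] > neg[j])


rng = np.random.default_rng(1)
y = rng.integers(0, 2, size=5_000)
score = y + rng.normal(size=5_000)
print(sampled_roc_auc(y, score), metrics.roc_auc_score(y, score))
```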

<h2 id="rocauc-for-regression">ROCAUC for Regression</h2>

<p><div align="justify">This probabilistic interpretation suggests a natural extension to the regression problem [<a href="#bibliography">4</a>]. If we replace the condition $y_i = 1, y_j = 0$ with $y_i &gt; y_j$, we obtain a generic ranking probability metric for regression: the fraction of pairs with $y_i &gt; y_j$ that the model orders correctly. Denoting by $N$ the number of pairs $(i, j)$ with $y_i &gt; y_j$, we have</div></p>

\[\widehat{\text{ROCAUC}}(f) = \frac{1}{N} \sum_{(i,j): y_i &gt; y_j}  \mathbb{1}\left( f(x_i) &gt; f(x_j) \right).\]

<p><div align="justify">$\oint$ <em>Just like with Kendall's tau, it's possible to adapt the metric to account for sample weights by assigning the weight of a pair as the product of the weights of the samples.</em></div></p>
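<p><div align="justify">A minimal sketch of this weighted variant, written as the hypothetical helper <code>weighted_regression_roc_auc</code>, assumes each valid pair $(i, j)$ with $y_i &gt; y_j$ gets the weight $w_i w_j$. The metric is then the weighted fraction of those pairs that are correctly ordered. The sketch also counts tied predictions as $1/2$, the convention <code>roc_auc_score</code> uses in classification. With unit weights, it reduces to the unweighted metric.</div></p>

```python
import numpy as np


def weighted_regression_roc_auc(y_true, y_score, sample_weight):
    """Hypothetical weighted regression ROCAUC: each pair (i, j) with
    y_true[i] > y_true[j] gets weight sample_weight[i] * sample_weight[j].
    Ties in y_score count as 1/2."""
    y_true = np.asarray(y_true, dtype=float)
    y_score = np.asarray(y_score, dtype=float)
    w = np.asarray(sample_weight, dtype=float)

    valid_pairs = y_true[:, None] > y_true[None, :]
    pair_weights = np.outer(w, w) * valid_pairs  # zero for non-valid pairs
    total_weight = pair_weights.sum()
    if total_weight == 0:
        return 0.5  # no comparable pairs: fall back to random ordering

    pred_diff = y_score[:, None] - y_score[None, :]
    credit = (pred_diff > 0) + 0.5 * (pred_diff == 0)
    return (pair_weights * credit).sum() / total_weight


print(weighted_regression_roc_auc([1, 2, 3], [1, 3, 2], [1, 1, 1]))  # 2/3
print(weighted_regression_roc_auc([1, 2, 3], [1, 3, 2], [2, 1, 1]))  # 4/5
```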

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn</span> <span class="kn">import</span> <span class="n">utils</span>

<span class="k">def</span> <span class="nf">regression_roc_auc</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">):</span>
    <span class="s">"""Compute the generalized ROC AUC for regression tasks.

    This function calculates the probability that the predicted values maintain
    the correct order relative to the true values, specifically for pairs where
    y_true[i] &gt; y_true[j].

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True continuous target values.

    y_score : array-like of shape (n_samples,)
        Predicted continuous target values.

    Returns
    -------
    score : float
        The computed generalized ROC AUC score for regression.
    """</span>
    <span class="n">y_true</span> <span class="o">=</span> <span class="n">utils</span><span class="p">.</span><span class="n">check_array</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">ensure_2d</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
    <span class="n">y_score</span> <span class="o">=</span> <span class="n">utils</span><span class="p">.</span><span class="n">check_array</span><span class="p">(</span><span class="n">y_score</span><span class="p">,</span> <span class="n">ensure_2d</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

    <span class="c1"># Vectorized over all pairs, without explicit loops. This builds n x n
</span>    <span class="c1"># matrices, so memory grows as O(n^2) with the number of samples.
</span>    <span class="c1"># Mask of all pairs where y_true[i] &gt; y_true[j]
</span>    <span class="n">diff_matrix</span> <span class="o">=</span> <span class="n">y_true</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">y_true</span><span class="p">[</span><span class="bp">None</span><span class="p">,</span> <span class="p">:]</span>
    <span class="n">valid_pairs</span> <span class="o">=</span> <span class="n">diff_matrix</span> <span class="o">&gt;</span> <span class="mi">0</span>

    <span class="c1"># Count total valid pairs
</span>    <span class="n">total_pairs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">valid_pairs</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">total_pairs</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
        <span class="c1"># If no valid pairs, return 0.5 (equivalent to random ordering)
</span>        <span class="k">return</span> <span class="mf">0.5</span>

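    <span class="c1"># Compare predictions for valid pairs. Tied predictions count as 1/2,
</span>    <span class="c1"># the same convention roc_auc_score uses for tied scores in classification.
</span>    <span class="n">pred_diff</span> <span class="o">=</span> <span class="n">y_score</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">y_score</span><span class="p">[</span><span class="bp">None</span><span class="p">,</span> <span class="p">:]</span>
    <span class="n">correct_orderings</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">((</span><span class="n">pred_diff</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="n">valid_pairs</span><span class="p">)</span>
    <span class="n">tied_orderings</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">((</span><span class="n">pred_diff</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="n">valid_pairs</span><span class="p">)</span>

    <span class="n">score</span> <span class="o">=</span> <span class="p">(</span><span class="n">correct_orderings</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">tied_orderings</span><span class="p">)</span> <span class="o">/</span> <span class="n">total_pairs</span>
    <span class="k">return</span> <span class="n">score</span>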

<span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">y_score</span> <span class="ow">in</span> <span class="n">SCORES</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">roc_auc</span> <span class="o">=</span> <span class="n">regression_roc_auc</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"ROCAUC for </span><span class="si">{</span><span class="n">score_name</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">roc_auc</span><span class="si">:</span><span class="mf">7.5</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>ROCAUC for y_score_1: 0.98081
ROCAUC for y_score_2: 0.90113
ROCAUC for y_score_3: 0.50488
</code></pre></div></div>
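<p><div align="justify">Because <code>regression_roc_auc</code> builds $n \times n$ matrices, its memory use grows quadratically. For $n = 100{,}000$, each matrix would have $10^{10}$ entries. A hypothetical Monte Carlo variant, <code>sampled_regression_roc_auc</code>, follows the spirit of the sampled classification estimator. It draws random index pairs, discards those with tied targets, orients the rest so that the first index has the larger target, and averages the correct orderings. Each unordered pair with distinct targets is equally likely to be drawn, so this estimates the exhaustive value.</div></p>

```python
import numpy as np


def sampled_regression_roc_auc(y_true, y_score, n_pairs=200_000, random_state=0):
    """Monte Carlo estimate of the regression ROCAUC (hypothetical helper).

    Draws random index pairs, drops those with tied targets, and orients the
    rest so the first index has the larger target. Ties in y_score count as 1/2.
    """
    rng = np.random.default_rng(random_state)
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    i = rng.integers(0, len(y_true), size=n_pairs)
    j = rng.integers(0, len(y_true), size=n_pairs)
    keep = y_true[i] != y_true[j]  # also drops pairs with i == j
    if not keep.any():
        return 0.5
    i, j = i[keep], j[keep]

    # Orient each pair so that y_true[hi] > y_true[lo]
    swap = y_true[i] < y_true[j]
    hi = np.where(swap, j, i)
    lo = np.where(swap, i, j)

    pred_diff = y_score[hi] - y_score[lo]
    return np.mean((pred_diff > 0) + 0.5 * (pred_diff == 0))
```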

<p><div align="justify">These metrics are quite useful, but if your regression problem is highly imbalanced, you may run into difficulties. I have worked on regression problems where over 99.5% of the data had values equal to 0, with only a small fraction having any associated value. Because most target values are then tied, the previous correlation metrics may become artificially inflated or deflated depending on how your implementation handles ties, and nothing in the resulting number signals the issue. In the case of ROCAUC, every pair with $y_i = y_j$ is discarded, so the metric relies on a small fraction of the pairs, which can make its value less reliable and high-variance.</div></p>
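<p><div align="justify">To get a sense of the scale, the sketch below builds a hypothetical zero-inflated target with 2,000 samples, only 10 of which are nonzero. It then counts how many pairs satisfy $y_i &gt; y_j$ and can therefore enter the regression ROCAUC.</div></p>

```python
import numpy as np

n = 2_000
y_zero_inflated = np.zeros(n)
y_zero_inflated[:10] = np.arange(1, 11)  # only 0.5% of the targets are nonzero

# Pairs that the regression ROCAUC can actually use: y[i] > y[j]
n_valid = int((y_zero_inflated[:, None] > y_zero_inflated[None, :]).sum())
n_total = n * (n - 1) // 2  # all unordered pairs

print(n_valid, n_total)  # 19945 of 1999000 pairs, about 1%
```

<p><div align="justify">Only $10 \times 1{,}990 + 45 = 19{,}945$ pairs are usable, and every one of them involves one of the 10 nonzero samples.</div></p>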

<hr />

<h2 id="ranking-curve">Ranking Curve</h2>

<p><div align="justify">The ranking curve (I’m not sure if this curve has an official name) is interesting because it is very simple and intuitive. You sort your sample by the predicted score and divide it into "buckets" according to percentiles of that score. Then you calculate the mean, or another relevant summary statistic, of the true values in each bucket. For example, if you divide the sample into 10 buckets, the third bucket contains the samples whose scores fall between the 20th and 30th percentiles, and you compute the mean of their true values.</div></p>

<p><div align="justify">The idea is that, if your score ranks the sample well, the samples with the highest true values will concentrate in the top buckets and those with the lowest true values in the bottom ones. As a result, the curve will increase with a steep slope. A score with no ranking ability produces a roughly flat curve around the overall mean.</div></p>

<p><div align="justify">$\oint$ <em>I usually use 10 buckets, but this number is a parameter you can adjust depending on the level of detail you want to observe. The trade-off is that finer detail gives a noisier result, because each bucket contains fewer samples. However, by using a bootstrap method, you can plot a confidence interval to support the analysis.</em></div></p>
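<p><div align="justify">A minimal sketch of the bootstrap idea makes two assumptions: a percentile bootstrap over the test set, and bucketing with <code>np.array_split</code>. <code>bucket_means</code> and <code>bootstrap_ranking_curve</code> are illustrative helpers, not part of the <code>ranking_curve</code> function below. The sketch resamples the test set with replacement, recomputes the bucket means each time, and takes pointwise quantiles as the band.</div></p>

```python
import numpy as np


def bucket_means(y_true, y_score, n_buckets=10):
    """Mean of y_true in each of n_buckets buckets of samples sorted by y_score."""
    y_true_sorted = np.asarray(y_true)[np.argsort(y_score)]
    return np.array([b.mean() for b in np.array_split(y_true_sorted, n_buckets)])


def bootstrap_ranking_curve(y_true, y_score, n_buckets=10, n_boot=500,
                            alpha=0.05, random_state=0):
    """Ranking curve plus a percentile-bootstrap band (hypothetical helper)."""
    rng = np.random.default_rng(random_state)
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    n = len(y_true)

    boot_curves = np.empty((n_boot, n_buckets))
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)  # resample the test set with replacement
        boot_curves[b] = bucket_means(y_true[idx], y_score[idx], n_buckets)

    lower, upper = np.quantile(boot_curves, [alpha / 2, 1 - alpha / 2], axis=0)
    return bucket_means(y_true, y_score, n_buckets), lower, upper
```

<p><div align="justify">With the objects defined at the start of the post, a call such as <code>bootstrap_ranking_curve(y_true, SCORES['y_score_2'])</code> returns the curve and its lower and upper bounds. These could be drawn with, for example, matplotlib's <code>fill_between</code>.</div></p>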

<p><div align="justify">$\oint$ <em>This construction is not a <a href="https://vitaliset.github.io/covariate-shift-1-qqplot/">QQ-plot</a>, but understanding how a QQ-plot works may help you grasp the construction of this metric, even though this curve is much simpler.</em></div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">ranking_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">n_buckets</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">statistic</span><span class="o">=</span><span class="s">'mean'</span><span class="p">):</span>
    <span class="s">"""Compute the ranking curve for a regression task.

    Sorts the samples by `y_score`, splits them into `n_buckets` buckets of
    (approximately) equal size and computes a statistic of the `y_true` values
    in each bucket. It can be used to assess the distribution or trend of the
    true values as a function of the predicted scores.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True continuous target values.

    y_score : array-like of shape (n_samples,)
        Predicted continuous target values.

    n_buckets : int, default=10
        The number of buckets to divide the sorted `y_score` values into.

    statistic : {'mean', 'median'} or callable, default='mean'
        The statistic to compute for `y_true` values in each bucket.
        - If 'mean', computes the mean of `y_true` in each bucket.
        - If 'median', computes the median of `y_true` in each bucket.
        - If callable, applies the callable function to the `y_true` values in each bucket.

    Returns
    -------
    bucket_positions : ndarray of shape (n_buckets,)
        The positions of the buckets, indexed from 1 to `n_buckets`.

    bucket_values : ndarray of shape (n_buckets,)
        The computed statistic for `y_true` values in each bucket.
    """</span>
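    <span class="n">y_true</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">y_true</span><span class="p">)</span>
    <span class="n">sorted_indices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">y_score</span><span class="p">)</span>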
    <span class="n">y_true_sorted</span> <span class="o">=</span> <span class="n">y_true</span><span class="p">[</span><span class="n">sorted_indices</span><span class="p">]</span>

    <span class="n">bucket_edges</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">y_true</span><span class="p">),</span> <span class="n">n_buckets</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">int</span><span class="p">)</span>
    <span class="n">bucket_values</span> <span class="o">=</span> <span class="p">[]</span>

    <span class="k">if</span> <span class="n">statistic</span> <span class="o">==</span> <span class="s">'mean'</span><span class="p">:</span>
        <span class="n">stat_func</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span>
    <span class="k">elif</span> <span class="n">statistic</span> <span class="o">==</span> <span class="s">'median'</span><span class="p">:</span>
        <span class="n">stat_func</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">median</span>
    <span class="k">elif</span> <span class="nb">callable</span><span class="p">(</span><span class="n">statistic</span><span class="p">):</span>
        <span class="n">stat_func</span> <span class="o">=</span> <span class="n">statistic</span>
    <span class="k">else</span><span class="p">:</span>
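        <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="s">"statistic must be 'mean', 'median' or a callable."</span><span class="p">)</span>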

    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_buckets</span><span class="p">):</span>
        <span class="n">start</span><span class="p">,</span> <span class="n">end</span> <span class="o">=</span> <span class="n">bucket_edges</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">bucket_edges</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
        <span class="n">bin_values</span> <span class="o">=</span> <span class="n">y_true_sorted</span><span class="p">[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">]</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">bin_values</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">bucket_stat</span> <span class="o">=</span> <span class="n">stat_func</span><span class="p">(</span><span class="n">bin_values</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">bucket_stat</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">nan</span>
        <span class="n">bucket_values</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">bucket_stat</span><span class="p">)</span>

    <span class="n">bucket_positions</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_buckets</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
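    <span class="k">return</span> <span class="n">bucket_positions</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">bucket_values</span><span class="p">)</span>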

<span class="n">N_BUCKETS</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">ordering_curve_dict</span> <span class="o">=</span> <span class="p">{}</span>

<span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">y_score</span> <span class="ow">in</span> <span class="n">SCORES</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">bins</span><span class="p">,</span> <span class="n">ordering_curve</span> <span class="o">=</span> <span class="n">ranking_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">N_BUCKETS</span><span class="p">,</span> <span class="s">'mean'</span><span class="p">)</span>
    <span class="n">ordering_curve_dict</span><span class="p">[</span><span class="n">score_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">ordering_curve</span>
</code></pre></div></div>

<p><div align="justify">It is useful to compare each curve with a random ordering, which would spread <code>y_true</code> uniformly across all buckets. Because there would be no relationship between the order and <code>y_true</code>, the expected mean of every bucket would be the same: the overall mean of <code>y_true</code>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dpi</span><span class="o">=</span><span class="mi">130</span><span class="p">)</span>

<span class="n">ax</span><span class="p">.</span><span class="n">hlines</span><span class="p">(</span><span class="n">y_true</span><span class="p">.</span><span class="n">mean</span><span class="p">(),</span> <span class="nb">min</span><span class="p">(</span><span class="n">bins</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">bins</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">'random ordering'</span><span class="p">,</span> <span class="n">colors</span><span class="o">=</span><span class="s">'k'</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
<span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">ordering_curve</span> <span class="ow">in</span> <span class="n">ordering_curve_dict</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">bins</span><span class="p">,</span> <span class="n">ordering_curve</span><span class="p">,</span> <span class="s">'-o'</span><span class="p">,</span> <span class="n">markeredgecolor</span><span class="o">=</span><span class="s">'k'</span><span class="p">,</span> <span class="n">markeredgewidth</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">score_name</span><span class="p">)</span>

<span class="n">ax</span><span class="p">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="n">bins</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"Mean of y for each bucket"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"buckets"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/evaluating_ranking_in_regression/output_17_0.png" /></center></div></p>
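<p><div align="justify">To see why the random baseline is a flat line, here is a small simulation on synthetic data (a sketch, not the data behind the plot above). <code>bucket_means</code> is a hypothetical compact stand-in for <code>ranking_curve(..., statistic='mean')</code> that splits the sorted targets with <code>np.array_split</code>, so its bucket boundaries can differ by one sample from the <code>np.linspace</code> edges used in <code>ranking_curve</code>.</div></p>

```python
import numpy as np

def bucket_means(y_true, y_score, n_buckets=10):
    """Hypothetical stand-in for ranking_curve(..., statistic='mean'):
    sort y_true by score, split into equal-count buckets, average each."""
    y_sorted = np.asarray(y_true, dtype=float)[np.argsort(y_score)]
    return np.array([b.mean() for b in np.array_split(y_sorted, n_buckets)])

rng = np.random.default_rng(42)
n = 100_000
y_demo = rng.normal(size=n)                              # synthetic target
informative_score = y_demo + rng.normal(scale=0.5, size=n)  # correlated with y_demo
random_score = rng.uniform(size=n)                       # unrelated to y_demo

informative_curve = bucket_means(y_demo, informative_score)
random_curve = bucket_means(y_demo, random_score)

# Largest gap between a random-ordering bucket mean and the global mean (small)
print(np.abs(random_curve - y_demo.mean()).max())
# Spread between the extreme buckets of the informative score (large)
print(informative_curve[-1] - informative_curve[0])
```

With 10,000 samples per bucket, the random-ordering bucket means only fluctuate around the global mean by sampling noise, while the informative score produces a steadily increasing curve.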

<p><div align="justify">It is also very useful to summarize this curve into scalar values that can be used to compare models, for instance during hyperparameter optimization. Some of the metrics I like to use are:</div></p>

<ul>
  <li>
    <p><div align="justify">The value of the last bucket.</div></p>
  </li>
  <li>
    <p><div align="justify">The value of the first bucket.</div></p>
  </li>
  <li>
    <p><div align="justify">The difference between the last and the first bucket (through a telescoping sum, this equals the sum of the variations between consecutive buckets, that is, <code>n_buckets - 1</code> times their mean).</div></p>
  </li>
  <li>
    <p><div align="justify">The slope of a linear regression fitted to the points.</div></p>
  </li>
</ul>
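<p><div align="justify">A quick numerical check of the telescoping argument, on a made-up ordering curve (the values are illustrative only):</div></p>

```python
import numpy as np

# Made-up bucket means of an ordering curve, from first to last bucket
curve_demo = np.array([-1.8, -1.1, -0.6, -0.2, 0.1, 0.3, 0.6, 0.9, 1.3, 1.8])

variations = np.diff(curve_demo)         # curve[k + 1] - curve[k], 9 variations
last_minus_first = curve_demo[-1] - curve_demo[0]

print(last_minus_first)
print(variations.sum())                  # same value: intermediate terms cancel
print(variations.mean() * (len(curve_demo) - 1))  # mean variation times (n_buckets - 1)
```

So ranking models by the last-minus-first difference is the same as ranking them by the mean variation between consecutive buckets, as long as <code>n_buckets</code> is fixed.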

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">last_bucket_ordering_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">n_buckets</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">statistic</span><span class="o">=</span><span class="s">'mean'</span><span class="p">):</span>
    <span class="n">_</span><span class="p">,</span> <span class="n">ordering_curve</span> <span class="o">=</span> <span class="n">ranking_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">n_buckets</span><span class="o">=</span><span class="n">n_buckets</span><span class="p">,</span> <span class="n">statistic</span><span class="o">=</span><span class="n">statistic</span><span class="p">)</span>
    <span class="o">*</span><span class="n">_</span><span class="p">,</span> <span class="n">last_bucket</span> <span class="o">=</span> <span class="n">ordering_curve</span>
    <span class="k">return</span> <span class="n">last_bucket</span>

<span class="k">def</span> <span class="nf">first_bucket_ordering_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">n_buckets</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">statistic</span><span class="o">=</span><span class="s">'mean'</span><span class="p">):</span>
    <span class="n">_</span><span class="p">,</span> <span class="n">ordering_curve</span> <span class="o">=</span> <span class="n">ranking_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">n_buckets</span><span class="o">=</span><span class="n">n_buckets</span><span class="p">,</span> <span class="n">statistic</span><span class="o">=</span><span class="n">statistic</span><span class="p">)</span>
    <span class="n">first_bucket</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="n">ordering_curve</span>
    <span class="k">return</span> <span class="n">first_bucket</span>

<span class="k">def</span> <span class="nf">diff_bucket_ordering_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">n_buckets</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">statistic</span><span class="o">=</span><span class="s">'mean'</span><span class="p">):</span>
    <span class="n">_</span><span class="p">,</span> <span class="n">ordering_curve</span> <span class="o">=</span> <span class="n">ranking_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">n_buckets</span><span class="o">=</span><span class="n">n_buckets</span><span class="p">,</span> <span class="n">statistic</span><span class="o">=</span><span class="n">statistic</span><span class="p">)</span>
    <span class="n">first_bucket</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span><span class="p">,</span> <span class="n">last_bucket</span> <span class="o">=</span> <span class="n">ordering_curve</span>
    <span class="k">return</span> <span class="n">last_bucket</span> <span class="o">-</span> <span class="n">first_bucket</span>

<span class="kn">from</span> <span class="nn">sklearn</span> <span class="kn">import</span> <span class="n">linear_model</span>

<span class="k">def</span> <span class="nf">linear_regression_coefficient_ordering_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">n_buckets</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">statistic</span><span class="o">=</span><span class="s">'mean'</span><span class="p">):</span>
    <span class="n">_</span><span class="p">,</span> <span class="n">ordering_curve</span> <span class="o">=</span> <span class="n">ranking_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">,</span> <span class="n">n_buckets</span><span class="o">=</span><span class="n">n_buckets</span><span class="p">,</span> <span class="n">statistic</span><span class="o">=</span><span class="n">statistic</span><span class="p">)</span>
    <span class="n">x_values</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n_buckets</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">y_values</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">ordering_curve</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

    <span class="n">model</span> <span class="o">=</span> <span class="n">linear_model</span><span class="p">.</span><span class="n">LinearRegression</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span><span class="n">x_values</span><span class="p">,</span> <span class="n">y_values</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">model</span><span class="p">.</span><span class="n">coef_</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
</code></pre></div></div>

<p><div align="justify">The higher the value of the last bucket, the larger the values of <code>y_true</code> among the samples with the highest scores, that is, the better the model concentrates high targets at the top of the ranking.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">y_score</span> <span class="ow">in</span> <span class="n">SCORES</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">last_bucket</span> <span class="o">=</span> <span class="n">last_bucket_ordering_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Last bucket for </span><span class="si">{</span><span class="n">score_name</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">last_bucket</span><span class="si">:</span><span class="mf">7.5</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Last bucket for y_score_1: 1.79617
Last bucket for y_score_2: 1.70048
Last bucket for y_score_3: 0.12308
</code></pre></div></div>

<p><div align="justify">The lower the value of the first bucket, the better the model concentrates low values of <code>y_true</code> among the samples with the lowest scores.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">y_score</span> <span class="ow">in</span> <span class="n">SCORES</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">first_bucket</span> <span class="o">=</span> <span class="n">first_bucket_ordering_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"First bucket for </span><span class="si">{</span><span class="n">score_name</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">first_bucket</span><span class="si">:</span><span class="mf">8.5</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>First bucket for y_score_1: -1.76345
First bucket for y_score_2: -1.70674
First bucket for y_score_3:  0.07232
</code></pre></div></div>

<p><div align="justify">The greater the difference between the last and the first buckets, the better the model separates the samples with the lowest scores from those with the highest scores.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">y_score</span> <span class="ow">in</span> <span class="n">SCORES</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">diff_bucket</span> <span class="o">=</span> <span class="n">diff_bucket_ordering_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Diff bucket for </span><span class="si">{</span><span class="n">score_name</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">diff_bucket</span><span class="si">:</span><span class="mf">7.5</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Diff bucket for y_score_1: 3.55962
Diff bucket for y_score_2: 3.40723
Diff bucket for y_score_3: 0.05076
</code></pre></div></div>

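<p><div align="justify">The steeper the (positive) slope of the line fitted to the bucket points, the faster the mean of <code>y_true</code> grows from one bucket to the next, indicating better ranking. Unlike the difference between the extremes, the slope uses every bucket, so it does not depend only on the first and last ones.</div></p>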

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">score_name</span><span class="p">,</span> <span class="n">y_score</span> <span class="ow">in</span> <span class="n">SCORES</span><span class="p">.</span><span class="n">items</span><span class="p">():</span>
    <span class="n">lr_bucket</span> <span class="o">=</span> <span class="n">linear_regression_coefficient_ordering_curve</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_score</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Linear regression coefficient for </span><span class="si">{</span><span class="n">score_name</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">lr_bucket</span><span class="si">:</span><span class="mf">7.5</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Linear regression coefficient for y_score_1: 0.34367
Linear regression coefficient for y_score_2: 0.32808
Linear regression coefficient for y_score_3: 0.00722
</code></pre></div></div>
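<p><div align="justify">Since the fit has a single feature, the slope returned by <code>LinearRegression</code> is just the ordinary least-squares slope $\sum_k (k - \bar{k})(v_k - \bar{v}) / \sum_k (k - \bar{k})^2$, where $v_k$ is the value of bucket $k$. The sketch below (with a made-up curve and a hypothetical helper <code>ols_slope</code>) checks this closed form against <code>np.polyfit</code> and shows that the slope does not change whether positions start at 0, as in <code>linear_regression_coefficient_ordering_curve</code>, or at 1, as returned by <code>ranking_curve</code>.</div></p>

```python
import numpy as np

def ols_slope(values):
    """Least-squares slope of the line through the points (k, values[k])."""
    values = np.asarray(values, dtype=float)
    k = np.arange(len(values))
    k_centered = k - k.mean()
    return np.sum(k_centered * (values - values.mean())) / np.sum(k_centered ** 2)

# Made-up bucket means, for illustration only
curve_demo = np.array([-1.8, -1.1, -0.6, -0.2, 0.1, 0.3, 0.6, 0.9, 1.3, 1.8])

slope_closed_form = ols_slope(curve_demo)
# np.polyfit returns coefficients from highest degree: [slope, intercept]
slope_from_zero = np.polyfit(np.arange(10), curve_demo, deg=1)[0]
slope_from_one = np.polyfit(np.arange(1, 11), curve_demo, deg=1)[0]

print(slope_closed_form, slope_from_zero, slope_from_one)  # all equal
```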

<p><div align="justify">$\oint$ <em>Adding sample weights to this curve is considerably more tedious, as you need to split the percentiles based on the sum of the weights, <a href="https://github.com/dihanster/datalib/issues/17#issue-1688825236">but it’s not impossible</a>. :)</em></div></p>
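<p><div align="justify">A minimal sketch of one way to split the buckets by weight, assuming non-negative weights with a positive total: each sample is assigned to a bucket according to the cumulative weight that precedes it in the score order, and each bucket reports the weighted mean of <code>y_true</code>. <code>weighted_ranking_curve</code> is a hypothetical helper written for this illustration, not necessarily the approach discussed in the linked issue.</div></p>

```python
import numpy as np

def weighted_ranking_curve(y_true, y_score, sample_weight, n_buckets=10):
    """Hypothetical weighted variant of ranking_curve(..., statistic='mean').

    Buckets hold roughly 1 / n_buckets of the total *weight* (not of the
    number of samples). Assumes non-negative weights with a positive sum.
    """
    y_true = np.asarray(y_true, dtype=float)
    weights = np.asarray(sample_weight, dtype=float)
    order = np.argsort(y_score, kind="stable")
    y_sorted, w_sorted = y_true[order], weights[order]

    # Weight accumulated before each sample (0 for the first one)
    cum_before = np.concatenate(([0.0], np.cumsum(w_sorted)[:-1]))
    bucket_idx = np.floor(n_buckets * cum_before / w_sorted.sum()).astype(int)
    bucket_idx = np.minimum(bucket_idx, n_buckets - 1)

    # Weighted mean of y_true per bucket; buckets without weight become NaN
    weighted_sum = np.bincount(bucket_idx, weights=w_sorted * y_sorted, minlength=n_buckets)
    weight_sum = np.bincount(bucket_idx, weights=w_sorted, minlength=n_buckets)
    with np.errstate(invalid="ignore", divide="ignore"):
        bucket_values = weighted_sum / weight_sum
    return np.arange(1, n_buckets + 1), bucket_values

# Usage: a heavy last sample fills the second bucket on its own
positions, values = weighted_ranking_curve(
    y_true=[0.0, 0.0, 0.0, 10.0],
    y_score=[1, 2, 3, 4],
    sample_weight=[1, 1, 1, 3],
    n_buckets=2,
)
print(positions, values)
```

With unit weights and a number of samples divisible by <code>n_buckets</code>, this reduces to the equal-count buckets used in <code>ranking_curve</code>.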

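<p><div align="justify">$\oint$ <em>This curve also works well for evaluating ranking performance in binary classification problems: with a 0/1 target, the mean of <code>y_true</code> in each bucket is simply the fraction of positives in that bucket.</em></div></p>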
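<p><div align="justify">As an illustration of the binary case, here is a sketch on synthetic labels (a hypothetical setup, not data from this post). <code>bucket_means</code> is again a compact stand-in for <code>ranking_curve(..., statistic='mean')</code>; on a 0/1 target each bucket value is the positive rate, and the random-ordering baseline is the overall positive rate.</div></p>

```python
import numpy as np

def bucket_means(y_true, y_score, n_buckets=10):
    """Hypothetical stand-in for ranking_curve(..., statistic='mean')."""
    y_sorted = np.asarray(y_true, dtype=float)[np.argsort(y_score)]
    return np.array([b.mean() for b in np.array_split(y_sorted, n_buckets)])

rng = np.random.default_rng(0)
x = rng.normal(size=50_000)
# Synthetic labels: positive when a noisy version of x is above zero
y_binary = (x + rng.normal(scale=0.5, size=x.size) > 0).astype(int)

positive_rate_per_bucket = bucket_means(y_binary, x)  # x used as the score
print(positive_rate_per_bucket.round(3))
print(y_binary.mean())  # flat line of a random ordering
```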

<hr />

<h2 id="final-considerations">Final considerations</h2>

<p><div align="justify">Although regression models often optimize metrics based on $| \hat{y_i} - y_i |$, I hope this discussion has inspired reflection on the limitations of such metrics. They may not always be the most appropriate choice and can sometimes obscure the true problem of interest.</div></p>

<p><div align="justify">The ranking metrics introduced here are each highly valuable, complementing one another depending on the specific context and problem at hand. Instead of striving for a single, universally applicable metric, it is often more effective to evaluate these metrics collectively. In practice, they tend to align and reinforce each other, offering a richer and more nuanced understanding of model performance.</div></p>

<p><div align="justify">Moreover, I encourage you to tweak existing metrics or develop custom variations, which can often uncover fresh perspectives on a problem. The ultimate goal is not merely to assign a score to a model but to ensure it aligns with the problem's objectives and delivers outcomes that are meaningful and actionable.</div></p>

<h2 id="bibliography"><a name="bibliography">Bibliography</a></h2>

<p align="justify">[1] <a href="https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient">Spearman's rank correlation coefficient. Wikipedia.</a></p>

<p align="justify">[2] <a href="https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient">Kendall rank correlation coefficient. Wikipedia.</a></p>

<p align="justify">[3] <a href="https://pibieta.github.io/imbalanced_learning/notebooks/Introduction.html">Imbalanced Binary Classification - A survey with code. Alessandro Morita, Juan Pablo Ibieta, Carlo Lemos.</a></p>

<p align="justify">[4] <a href="https://towardsdatascience.com/how-to-calculate-roc-auc-score-for-regression-models-c0be4fdf76bb">You Can Compute ROC Curve Also for Regression Models. Samuele Mazzanti.</a></p>
<hr />

<p><div align="justify">You can find all files and environments for reproducing the experiments in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/evaluating_ranking_in_regression">repository of this post</a>.</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇺🇸&quot;, &quot;basic&quot;]" /><summary type="html"><![CDATA[MSE and MAE can be misleading if your regression goal is to rank.]]></summary></entry><entry><title type="html">The R² score does not vary between 0 and 1</title><link href="https://vitaliset.github.io/r2-score/" rel="alternate" type="text/html" title="The R² score does not vary between 0 and 1" /><published>2023-10-12T00:00:00+00:00</published><updated>2023-10-12T00:00:00+00:00</updated><id>https://vitaliset.github.io/r2-score</id><content type="html" xml:base="https://vitaliset.github.io/r2-score/"><![CDATA[<p><div align="justify">This text has a Portuguese version, which can be found in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/r_squared">experiments repository</a>.</div></p>

<hr />

<p><div align="justify">The coefficient of determination, known as $R^2$, is a fundamental metric in regression analyses. However, its definition and interpretation are not always straightforward. Indeed, there are several ways to define $R^2$; they coincide for a linear model fitted by least squares with an intercept and evaluated on its own training data, but each offers a different interpretative nuance. Some of these interpretations are more intuitive, facilitating an immediate understanding of the possible values, while others can lead to misunderstandings.</div></p>

<p><div align="justify">The current version of scikit-learn, in its docstring for <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html"><code>sklearn.metrics.r2_score</code></a>, mentions that $R^2$ can range from negative infinity to 1. However, it&#39;s not uncommon to find data scientists claiming that $R^2$ always lies between 0 and 1. One of the reasons for this discrepancy lies in the classical interpretation of $R^2$, which is traditionally understood as the proportion of variance explained by the model relative to the total variance of the target variable [<a href="#bibliography">1</a>].</div></p>

<p><div align="justify">Throughout this text, I will address the interpretation that I consider most enlightening and relevant. With it, I hope to clarify some peculiarities of $R^2$ and highlight its importance as a solid and widely used metric in regression problems.</div></p>

<hr />

<h2 id="mean-squared-error-and-the-choice-of-a-constant-model">Mean Squared Error and the choice of a constant model</h2>

<p><div align="justify">The $R^2$ is a common metric in regression. However, often the first metric introduced for regression problems is the Mean Squared Error (MSE). The MSE of a model $h$ on a dataset $S = \{ (x_i, y_i) \}_{i=1}^n$ is defined by</div></p>

\[\textrm{MSE}(h) = \frac{1}{n} \sum_{i=1}^n \left(y_i - h(x_i)\right)^2,\]

<p><div align="justify">where we chose not to denote the dependence on $S$ in order to keep the notation more streamlined.</div></p>
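<p><div align="justify">As a sanity check of the definition, the sketch below computes the MSE by hand on made-up targets and predictions and compares it with <code>sklearn.metrics.mean_squared_error</code>:</div></p>

```python
import numpy as np
from sklearn.metrics import mean_squared_error

rng = np.random.default_rng(0)
y = rng.normal(loc=40, scale=10, size=200)      # made-up targets y_i
y_hat = y + rng.normal(scale=5, size=y.size)    # made-up predictions h(x_i)

mse_by_hand = np.mean((y - y_hat) ** 2)         # (1/n) * sum (y_i - h(x_i))^2
mse_sklearn = mean_squared_error(y, y_hat)
print(mse_by_hand, mse_sklearn)                 # same value
```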

<p><div align="justify">Given this definition, an intriguing question arises: if you had to create a model that was merely a constant, which value would you choose? Many might answer that they would choose the mean, which is indeed one of the correct answers. However, why not consider the median, mode, or some other descriptive statistic?</div></p>

<p><div align="justify">The answer to this question is intrinsically linked to the cost function we wish to optimize. This choice is, in fact, a problem of decision theory [<a href="#bibliography">2</a>]. For instance, if the goal is to optimize the MSE, then we would need to choose an $\alpha \in \mathbb{R}$ such that $h_\alpha(x) = \alpha$ minimizes the $\textrm{MSE}(h_\alpha)$. Mathematically, this is expressed as</div></p>

\[\alpha^* = \arg\min_{\alpha \in \mathbb{R}} \textrm{MSE}(h_\alpha) = \arg\min_{\alpha \in \mathbb{R}} \left( \frac{1}{n} \sum_{i=1}^n \left(y_i - \alpha\right)^2 \right).\]

<p><div align="justify">This function may seem complex at first glance, but it becomes simpler when considering only $\alpha$ as the free variable, which is how we approach this optimization problem. By expanding the square and performing the summation, we have a polynomial function of degree 2 in $\alpha$ in the form</div></p>

\[\frac{1}{n} \sum_{i=1}^n \left(y_i - \alpha\right)^2 = \frac{1}{n} \sum_{i=1}^n \left(y_i^2 -2\alpha y_i + \alpha^2 \right) = \alpha^2  + \left(\frac{-2}{n} \sum_{i=1}^n y_i\right) \alpha+ \left(\frac{1}{n} \sum_{i=1}^n y_i^2\right).\]

<p><div align="justify">In a quadratic function of the form $(a\,\alpha^2 + b\,\alpha + c)$, where $a&gt;0$, the minimum occurs at the vertex of the parabola, located at $\frac{-b}{2a}$. Thus, in our context, the minimum is</div></p>

\[\alpha^* = \frac{\left(\frac{-2}{n} \sum_{i=1}^n y_i\right)}{-2} = \frac{1}{n} \sum_{i=1}^n y_i = \bar{y}.\]

<p><div align="justify">This means that, when minimizing the MSE, the optimal constant value is the average of the target $\bar{y}$ for this set. I encourage validating this result using other unconstrained optimization techniques such as identifying critical points followed by analyzing the concavity of the function.</div></p>
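<p><div align="justify">Following that suggestion, a short version of the critical-point argument: differentiating the constant model&#39;s MSE with respect to $\alpha$ gives</div></p>

\[\frac{d}{d\alpha} \left( \frac{1}{n} \sum_{i=1}^n \left(y_i - \alpha\right)^2 \right) = -\frac{2}{n} \sum_{i=1}^n \left(y_i - \alpha\right) = 2\alpha - \frac{2}{n} \sum_{i=1}^n y_i,\]

<p><div align="justify">which vanishes only at $\alpha = \bar{y}$. The second derivative is the constant $2 &gt; 0$, so the function is convex and this critical point is its global minimum, matching the vertex computation.</div></p>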

<p><div align="justify">This behavior changes when considering other metrics [<a href="#bibliography">3</a>]. For example, the constant that minimizes the Mean Absolute Error (MAE) is the median, the one that maximizes accuracy is the mode, and for the pinball loss it&#39;s the associated quantile. It&#39;s important to emphasize that if we consider <code>sample_weight</code>, all these statistics should be computed in a weighted manner.</div></p>
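<p><div align="justify">These claims can be checked numerically with a brute-force search: evaluate the average loss of every constant $\alpha$ on a grid and keep the best one. The sketch below does this on a skewed synthetic sample (an assumption chosen so that mean, median and quantile differ); the loss helpers are written here for illustration, with the pinball loss in its usual form $\max(q\,u, (q-1)\,u)$ for $u = y - \alpha$.</div></p>

```python
import numpy as np

def squared_loss(y, alpha):
    return (y - alpha) ** 2

def absolute_loss(y, alpha):
    return np.abs(y - alpha)

def pinball_loss(q):
    def loss(y, alpha):
        diff = y - alpha
        return np.maximum(q * diff, (q - 1) * diff)
    return loss

rng = np.random.default_rng(0)
y = rng.exponential(scale=2.0, size=1001)    # skewed sample: mean != median
grid = np.linspace(y.min(), y.max(), 2001)   # candidate constants alpha
step = grid[1] - grid[0]

def best_constant(loss):
    # Average loss of the constant model h_alpha(x) = alpha, for each alpha
    avg_loss = loss(y[None, :], grid[:, None]).mean(axis=1)
    return grid[np.argmin(avg_loss)]

print(best_constant(squared_loss), y.mean())                  # close to the mean
print(best_constant(absolute_loss), np.median(y))             # close to the median
print(best_constant(pinball_loss(0.9)), np.quantile(y, 0.9))  # close to the 0.9 quantile
```

Since each average loss is convex in $\alpha$, the best grid point lies within one grid step of the true minimizer.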

<p><div align="justify">$\oint$ <em>This is used in defining prediction values for the nodes of decision trees. Looking at the scikit-learn code for trees, we notice that, depending on the criterion, the <a href="https://github.com/scikit-learn/scikit-learn/blob/d7a114413d1f11bf5f7029cd519c9a29a66b1890/sklearn/tree/_criterion.pyx#L1036"><code>node_value</code></a> can vary. It&#39;s adjusted to reflect the value that minimizes the loss when the node makes a constant prediction. For example, for the MSE criterion, the leaf&#39;s prediction is the average of the target of the training samples that fall in that leaf, while for the MAE criterion, it&#39;s the median.</em></div></p>

<p><div align="justify">$\oint$ <em>In practice, a model that predicts the target&#39;s average isn&#39;t feasible because to calculate the average of the test set, you would need to know the $y_i$ values of that sample. However, this perspective is useful for comparing a basic model with your model, as we will discuss next.</em></div></p>

<hr />

<h2 id="r-as-a-comparison-between-your-model-and-a-constant-model">R² as a comparison between your model and a constant model</h2>

<p><div align="justify">Suppose I develop a model to predict a person&#39;s age based on their online behavior and obtain an MSE of 25 years squared. This number on its own might not be very informative. One way to interpret it is to calculate the Root Mean Squared Error, that is, $\textrm{RMSE} = \sqrt{\textrm{MSE}}$, resulting in an error of about 5 years. This value is more intuitive (I admit that, internally, I tend to think in terms of MAE), but it still doesn&#39;t provide a relative comparison like &quot;is it possible to get a value significantly lower than this?&quot;. The $R^2$ might not answer this question directly, but it aids in this evaluation.</div></p>

<p><div align="justify">We&#39;ve already discussed a simple model that can serve as a benchmark. Imagine that the mean-based model already produces an MSE of 30 years squared. Suddenly, our previous model, which might have seemed excellent, doesn&#39;t stand out as much. If a simple model already achieves an MSE just slightly higher than the current model, is it worth implementing the more complex model in a production environment?</div></p>

<p><div align="justify">The interpretation I have of $R^2$ is precisely this comparison. Its formula is</div></p>

\[R^2(h) = 1 - \frac{\textrm{MSE}(h)}{\textrm{MSE}(\bar{y})},\]

<p><div align="justify">where $\bar{y}$ represents the average of the target in the set $S$ in which we are evaluating the model.</div></p>

<p><div align="justify">With this, we can understand the possible values of $R^2$:</div></p>

<ul>
  <li>
    <p><div align="justify">If $R^2 = 1$, it means that $\textrm{MSE}(h) = 0$; that is, the model is perfect.</div></p>
  </li>
  <li>
    <p><div align="justify">If $R^2 = 0$, we have $\textrm{MSE}(h) = \textrm{MSE}(\bar{y})$, indicating that our model is as effective as a model that simply provides the target&#39;s average.</div></p>
  </li>
  <li>
    <p><div align="justify">For an $R^2$ between 0 and 1, we have $0 &lt; \textrm{MSE}(h) &lt; \textrm{MSE}(\bar{y})$. This indicates that the model has an error greater than zero, but less than that of a constant model based on the average.</div></p>
  </li>
  <li>
    <p><div align="justify">A negative $R^2$ suggests that $\textrm{MSE}(h) &gt; \textrm{MSE}(\bar{y})$, meaning our model is less accurate than one that always provides the average.</div></p>
  </li>
</ul>
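<p><div align="justify">The four cases above can be reproduced with <code>sklearn.metrics.r2_score</code> on made-up data: a perfect model, the mean model, a noisy but informative model, and a "model" that shuffles the targets (so its predictions are unrelated to the true values). The sketch also checks the formula $1 - \textrm{MSE}(h)/\textrm{MSE}(\bar{y})$ against scikit-learn.</div></p>

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

rng = np.random.default_rng(0)
y = rng.normal(loc=40, scale=10, size=200)       # made-up targets

predictions = {
    "perfect": y.copy(),                         # MSE(h) = 0
    "mean": np.full_like(y, y.mean()),           # MSE(h) = MSE(y_bar)
    "noisy": y + rng.normal(scale=5, size=y.size),  # informative but imperfect
    "shuffled": rng.permutation(y),              # unrelated to the true values
}

r2 = {name: r2_score(y, pred) for name, pred in predictions.items()}
for name, value in r2.items():
    print(f"{name:>8}: R2 = {value: .3f}")

# R^2 as the comparison between the model and the constant mean model
mse_mean_model = mean_squared_error(y, predictions["mean"])
r2_by_hand = 1 - mean_squared_error(y, predictions["noisy"]) / mse_mean_model
print(r2_by_hand, r2["noisy"])                   # same value
```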

<p><div align="justify">This interpretation helps in understanding the values obtained when using the function <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html"><code>sklearn.metrics.r2_score</code></a>. In the previous example, we would have an $R^2$ of $(1 - 25/30) \approx 0.17$, indicating a model that surpasses the simple model, but not very significantly.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">fetch_california_housing</span>
<span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">LinearRegression</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">mean_squared_error</span><span class="p">,</span> <span class="n">r2_score</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>

<span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span>
    <span class="o">*</span><span class="n">fetch_california_housing</span><span class="p">(</span><span class="n">return_X_y</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
    <span class="n">test_size</span><span class="o">=</span><span class="mf">0.33</span><span class="p">,</span>
    <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">lr</span> <span class="o">=</span> <span class="n">LinearRegression</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">evaluate_model</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"MSE: </span><span class="si">{</span><span class="n">mean_squared_error</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"R^2: </span><span class="si">{</span><span class="n">r2_score</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    
<span class="n">y_pred_lr</span> <span class="o">=</span>  <span class="n">lr</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span>
<span class="n">evaluate_model</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_pred_lr</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>MSE: 0.5369686543372444
R^2: 0.5970494128783965
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">y_mean_test</span> <span class="o">=</span> <span class="n">y_test</span><span class="p">.</span><span class="n">mean</span><span class="p">()</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">y_test</span><span class="p">)</span>
<span class="n">evaluate_model</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_mean_test</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>MSE: 1.3325918152222385
R^2: 0.0
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">y_pred_terrible_model</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">y_test</span><span class="p">)</span>
<span class="n">evaluate_model</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_pred_terrible_model</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>MSE: 5.6276808369101445
R^2: -3.2231092616846126
</code></pre></div></div>

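<p><div align="justify">Although a model with an $R^2$ of zero might seem like the lowest bar a model has to clear, this baseline actually benefits from data leakage. It predicts the average of the very set $S$ on which it is evaluated, information that is not available when the model is built. In practice, we build our models using training data, and in scenarios subject to &quot;dataset shift,&quot; fundamental statistics such as the average can change significantly. Using the training mean as the constant prediction on the test set illustrates this:</div></p>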

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">y_mean_train</span> <span class="o">=</span> <span class="n">y_train</span><span class="p">.</span><span class="n">mean</span><span class="p">()</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">y_test</span><span class="p">)</span>
<span class="n">evaluate_model</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_mean_train</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>MSE: 1.3326257277946882
R^2: -2.5448582275933163e-05
</code></pre></div></div>
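
<p><div align="justify">A short derivation makes this concrete. For a constant prediction $c$, $\textrm{MSE}(c) = \textrm{Var}(y) + (\bar{y} - c)^2$, where $\textrm{Var}$ uses the denominator $n$. It follows that $R^2(c) = -(\bar{y} - c)^2 / \textrm{Var}(y) \leq 0$, with equality only when $c = \bar{y}$. Any constant other than the mean of the evaluation set itself, including the training mean, therefore gets a negative $R^2$. The sketch below checks this identity numerically. The data and the size of the mean shift are made up for illustration.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.metrics import r2_score


def r2_of_constant(y_true, c):
    # R^2 of a model that always predicts the constant c
    return r2_score(y_true, np.full(len(y_true), c, dtype=float))


def r2_of_constant_closed_form(y_true, c):
    # MSE(c) = Var(y) + (mean(y) - c)^2, hence R^2(c) = -(mean(y) - c)^2 / Var(y)
    y_true = np.asarray(y_true, dtype=float)
    return -((y_true.mean() - c) ** 2) / y_true.var()


rng = np.random.default_rng(0)
y_train = rng.normal(loc=0.0, scale=1.0, size=1_000)
y_test = rng.normal(loc=0.5, scale=1.0, size=1_000)  # hypothetical shift in the mean

c_train = y_train.mean()
print(r2_of_constant(y_test, c_train), r2_of_constant_closed_form(y_test, c_train))
</code></pre></div></div>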

<p><div align="justify">Regardless of these nuances, interpreting the $R^2$ in this way offers a valuable comparative mindset. It&#39;s always essential to compare your model with simple baselines, whether with established business rules or with more basic models, like a constant.</div></p>

<hr />

<h2 id="generalization-of-r-beyond-mse">Generalization of R² beyond MSE</h2>

<p><div align="justify">The notion of comparison with a simple baseline model generalizes easily to other metrics, as long as we know which statistic to use as the baseline. With this in mind, I propose extending the idea to the MAE. Here the baseline model is the median $\tilde{y}$, because the median is the constant prediction that minimizes the MAE, just as the mean is the constant that minimizes the MSE:</div></p>

\[R^2_{\textrm{MAE}}(h) = 1 - \frac{\textrm{MAE}(h)}{\textrm{MAE}(\tilde{y})},\]

<p><div align="justify">where</div></p>

\[\textrm{MAE}(h) = \frac{1}{n} \sum_{i=1}^n \left| y_i - h(x_i) \right|.\]

<p><div align="justify">Thus, the $R^2_{\textrm{MAE}}$ provides a way to evaluate the model&#39;s performance relative to a simple baseline, using the MAE as the error metric.</div></p>
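
<p><div align="justify">The choice of baselines is not arbitrary. Among constant predictions, the mean minimizes the MSE and the median minimizes the MAE. The sketch below checks this with a brute-force grid search on a made-up skewed sample, chosen so that the mean and the median differ noticeably.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

rng = np.random.default_rng(0)
y = rng.exponential(scale=2.0, size=501)  # made-up skewed target

# Evaluate every constant prediction on a fine grid
candidates = np.linspace(y.min(), y.max(), 2_001)
mae = np.array([np.mean(np.abs(y - c)) for c in candidates])
mse = np.array([np.mean((y - c) ** 2) for c in candidates])

best_c_mae = candidates[mae.argmin()]
best_c_mse = candidates[mse.argmin()]
print("best constant for MAE:", best_c_mae, "| median:", np.median(y))
print("best constant for MSE:", best_c_mse, "| mean:", y.mean())
</code></pre></div></div>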

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">mean_absolute_error</span>

<span class="k">def</span> <span class="nf">r2_score_mae</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    <span class="c1"># MAE of the model being evaluated.</span>
    <span class="n">mae_model</span> <span class="o">=</span> <span class="n">mean_absolute_error</span><span class="p">(</span><span class="n">y_true</span><span class="o">=</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_pred</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    <span class="c1"># Constant baseline: the median is the best constant prediction under the MAE.</span>
    <span class="n">y_median_true</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">median</span><span class="p">(</span><span class="n">y_true</span><span class="p">)</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">y_true</span><span class="p">)</span>
    <span class="n">mae_median</span> <span class="o">=</span> <span class="n">mean_absolute_error</span><span class="p">(</span>
        <span class="n">y_true</span><span class="o">=</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_median_true</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">mae_model</span> <span class="o">/</span> <span class="n">mae_median</span>

<span class="k">def</span> <span class="nf">evaluate_model_mae</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"MAE: </span><span class="si">{</span><span class="n">mean_absolute_error</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"R^2_MAE: </span><span class="si">{</span><span class="n">r2_score_mae</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

<span class="n">evaluate_model_mae</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_pred_lr</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>MAE: 0.5295710106684688
R^2_MAE: 0.40256278728026484
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">y_median_test</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">median</span><span class="p">(</span><span class="n">y_test</span><span class="p">)</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">y_test</span><span class="p">)</span>
<span class="n">evaluate_model_mae</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_median_test</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>MAE: 0.8864044612448619
R^2_MAE: 0.0
</code></pre></div></div>

<hr />

<h2 id="final-considerations">Final considerations</h2>

<p><div align="justify">The misconception that $R^2$ only varies between 0 and 1 stems from a simplified reading of its most common interpretation: the proportion of the target&#39;s variance explained by the independent variables, which suggests a value between 0% and 100%. In practice, $R^2$ does fall within this range in many cases. However, whenever the model is worse than a simple horizontal model (i.e., a constant line at the average), $R^2$ becomes negative. This scenario is often overlooked by the statistical community because it is usually associated with overfitting. A linear regression with an intercept tends to underfit rather than overfit, and it contains the horizontal model in its hypothesis space. When fitted by least squares, it can therefore never be worse than the average on its own training data, and it is rarely worse than it on new data either.</div></p>
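
<p><div align="justify">To illustrate both claims, the sketch below uses made-up data where the target is pure noise. An ordinary least squares fit with an intercept cannot score a negative $R^2$ on its own training data. A fully grown decision tree, which memorizes the noise, ends up with a negative $R^2$ on held-out data. The dataset and the choice of models are assumptions for illustration only.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(2_000, 5))
y = rng.normal(size=2_000)  # pure noise: no feature carries any signal

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=0)

ols = LinearRegression().fit(X_tr, y_tr)  # intercept included by default
tree = DecisionTreeRegressor(random_state=0).fit(X_tr, y_tr)  # fully grown, overfits

print("OLS  train/test R^2:", r2_score(y_tr, ols.predict(X_tr)), r2_score(y_te, ols.predict(X_te)))
print("Tree train/test R^2:", r2_score(y_tr, tree.predict(X_tr)), r2_score(y_te, tree.predict(X_te)))
</code></pre></div></div>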

<p><div align="justify">Throughout this post, we analyzed some of the reasons why $R^2$ is such an interesting and widely used metric in regression problems. By understanding its implicit comparison with a baseline model, we gain a valuable perspective on the relative performance of our model, normalizing MSE values that are hard to interpret in isolation. Moreover, this interpretation lets us read the resulting values, negative ones included, in a clear and objective way.</div></p>

<h2 id="bibliography"><a name="bibliography">Bibliography</a></h2>

<p><div align="justify">[1] <a href="https://en.wikipedia.org/wiki/Coefficient_of_determination">Coefficient of determination. Wikipedia.</a></div></p>

<p><div align="justify">[2] <a href="https://vfossaluza.github.io/InfBayes/TeoDec.html">Introdução à Teoria da Decisão. Fundamentos de Inferência Bayesiana. Victor Fossaluza and Luís Gustavo Esteves.</a></div></p>

<p><div align="justify">[3] <a href="https://vfossaluza.github.io/InfBayes/Estimacao.html#estima%C3%A7%C3%A3o-pontual">Estimação Pontual. Fundamentos de Inferência Bayesiana. Victor Fossaluza and Luís Gustavo Esteves.</a></div></p>
<hr />

<p><div align="justify">You can find all files and environments for reproducing the experiments in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/r_squared">repository of this post</a>.</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇺🇸&quot;, &quot;🇧🇷&quot;, &quot;basic&quot;]" /><summary type="html"><![CDATA[R&#178; as a comparison of the MSE with a simple baseline model and its potential for generalization.]]></summary></entry><entry><title type="html">Conformal prediction in CATE estimation</title><link href="https://vitaliset.github.io/cqr-cate/" rel="alternate" type="text/html" title="Conformal prediction in CATE estimation" /><published>2023-07-17T00:00:00+00:00</published><updated>2023-07-17T00:00:00+00:00</updated><id>https://vitaliset.github.io/cqr-cate</id><content type="html" xml:base="https://vitaliset.github.io/cqr-cate/"><![CDATA[<p><div align="justify">As we've discussed in the post about <a href="https://vitaliset.github.io/conditional-density-estimation/">Conditional Density Estimation</a>, having a sense of confidence associated with your prediction is important for decision making <a href="#bibliography">[1]</a>, and this is no different in applications of causal inference. Here, estimating confidence intervals for the Conditional Average Treatment Effect (CATE) can greatly enhance the validity of causal inference studies.</div></p>

<p><div align="justify">In the binary treatment scenario, $T\in\{0, 1\}$, CATE is defined as the expected difference in the outcome $Y$ between treating and not treating an individual with given observable characteristics. Mathematically, &quot;the average difference in expected potential outcomes conditional on the same covariates $Z=z$&quot; can be written in two ways, depending on the school of causal inference you come from. The first uses Pearl's do-notation and the second uses Rubin's potential outcomes <a href="#bibliography">[2, 3, 4]</a>:</div></p>

\[\begin{align*}
    \textrm{CATE}_{T, Y}(z) &amp;= \mathbb{E}(Y| do(T=1), Z=z) - \mathbb{E}(Y| do(T=0), Z=z)\\
    &amp;= \mathbb{E}(Y_1 | Z=z) - \mathbb{E}(Y_0 | Z=z).
\end{align*}\]

<p><div align="justify">CATE estimates the effect of a treatment at an individual level, taking into account the specific characteristics of each instance. This is extremely valuable in many industries, where understanding the effect of a treatment ($T$) on different subpopulations ($Z$) helps design personalized treatment plans for the desired outcome ($Y$).</div></p>

<hr />

<h2 id="brief-review-of-confounder-control">Brief review of confounder control</h2>

<p><div align="justify">It's common to choose as $Z$ a set of variables that satisfies the backdoor criterion for the causal effect of $T$ on $Y$. In Rubin's framework, the equivalent requirement is that $Z$ renders $T$ conditionally ignorable, i.e., $(Y_0, Y_1) \, \bot \, T \, | \, Z$. This is important because, in this scenario, $Z$ controls for confounders <a href="#bibliography">[2]</a>, and the causal effect is identified by</div></p>

\[f(z|do(T=t)) = f(z)\textrm{, and }f(y|do(T=t), Z=z) = f(y|T=t, Z=z).\]

<p><div align="justify">Consequently <a href="#bibliography">[2]</a></div></p>

\[\mathbb{E}(Y|do(T=t), Z=z) = \mathbb{E}(Y|T=t, Z=z).\]

<p><div align="justify">This relationship is crucial because it lets us estimate this quantity with any supervised machine learning model, for example by regressing $Y$ on $Z$ within each treatment group. This technique is known as the adjustment formula and comes in different flavors, such as meta-learners and matching <a href="#bibliography">[2, 3]</a>.</div></p>
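
<p><div align="justify">As a minimal sketch of this idea, the snippet below implements a T-learner, one of the simplest meta-learners. One regressor per treatment arm estimates $\mathbb{E}(Y|T=t, Z=z)$, and the CATE estimate is the difference between the two predictions. The toy data is hypothetical: it loosely mimics a confounded setting in which the true CATE is known. The use of <code>GradientBoostingRegressor</code> with default settings is an arbitrary choice for illustration.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(42)
n = 5_000

# Hypothetical toy data: Z confounds both the treatment T and the outcome Y.
Z = rng.uniform(-np.pi, np.pi, size=n)
propensity = 0.05 + 0.9 / (1 + np.exp(-Z))
T = rng.binomial(n=1, p=propensity)
Y = np.where(T == 1, 10 * np.cos(Z), 10 * np.sin(Z)) + rng.normal(scale=0.5, size=n)

# T-learner: one regressor per treatment arm, each estimating E(Y | T=t, Z=z).
model_t0 = GradientBoostingRegressor(random_state=0).fit(Z[T == 0].reshape(-1, 1), Y[T == 0])
model_t1 = GradientBoostingRegressor(random_state=0).fit(Z[T == 1].reshape(-1, 1), Y[T == 1])

# CATE estimate = difference between the two conditional expectations.
z_grid = np.linspace(-2.5, 2.5, 101).reshape(-1, 1)
cate_hat = model_t1.predict(z_grid) - model_t0.predict(z_grid)
cate_true = 10 * np.cos(z_grid.ravel()) - 10 * np.sin(z_grid.ravel())
print("mean absolute CATE error:", np.mean(np.abs(cate_hat - cate_true)))
</code></pre></div></div>

<p><div align="justify">This approach yields a point estimate only. Attaching a prediction interval to <code>cate_hat</code> requires combining the uncertainty of the two models, which is the difficulty discussed next.</div></p>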

<p><div align="justify">Despite its usefulness, applying conformal prediction to CATE estimation in this scenario is not straightforward. Binary CATE involves estimating two quantities, so we need some way to combine the prediction intervals of the two estimates. We will discuss how to do this without any parametric assumptions.</div></p>

<p><div align="justify">$\oint$ <em>In continuous treatment scenarios, my experience has been that $\mathbb{E}(Y| do(T=t), Z=z)$ is more informative than CATE, which is then defined as the derivative of this expectation with respect to $t$. It is easier to apply conformal prediction directly to $\mathbb{E}(Y| do(T=t), Z=z)$ because, with the adjustment formula, this is just a regression problem. If you really need CATE, on the other hand, interval estimation is much more complicated, and bootstrap strategies would be my approach. If you have another idea, please reach out!</em></div></p>

<hr />

<h2 id="creating-the-dataset">Creating the dataset</h2>

<p><div align="justify">To illustrate our application, we will use a simple causal graph in which $Z$ is a confounder and, on its own, satisfies the backdoor criterion.</div></p>

<p><div align="justify"><center><img src="/assets/img/cqr_cate/output_4_0.png" /></center></div></p>

<p><div align="justify">The corresponding structural causal model is given by</div></p>

\[U_Z \sim \textrm{Uniform}(-\pi, \pi)\textrm{, with }g_Z(u_Z) = u_Z,\]

\[U_T \sim \textrm{Uniform}(0, 1)\textrm{, with }\]

\[g_T(u_T, z) = \mathbb{1}(u_T \leq 0.05 + 0.9\, \sigma(z))\textrm{, where }\sigma(x) = \frac{1}{1 + \exp(-x)},\]

\[U_Y \sim \mathcal{N}(0, 1)\textrm{, with }\]

\[g_Y(u_Y, z, t) = \mathbb{1}(t=0) (10 \sin(z)) + \mathbb{1}(t=1) (10 \cos(z)) + 0.5 (1 + t\,|z|)\,u_Y.\]

<p><div align="justify">Note that this scenario is suitable for causal inference because the positivity assumption <a href="#bibliography">[5]</a> holds. Since $\sigma(z) \in (0, 1)$, we have $\mathbb{P}(T=1 | Z=z) = 0.05 + 0.9\, \sigma(z) \in (0.05, 0.95)$. In other words,</div></p>

\[0 &lt; \mathbb{P}(T=t | Z=z) &lt; 1 \textrm{, }\forall t \in \textrm{Im}(T)= \{ 0, 1\}, z \in \textrm{Im}(Z) = (-\pi, \pi).\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="k">def</span> <span class="nf">adapted_sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="k">return</span> <span class="mf">0.05</span> <span class="o">+</span> <span class="mf">0.9</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">))</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/cqr_cate/output_8_0.png" /></center></div></p>
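
<p><div align="justify">As a quick numerical check of the positivity bounds, the sketch below evaluates the propensity $0.05 + 0.9\,\sigma(z)$ on a grid over $\textrm{Im}(Z) = (-\pi, \pi)$. It redefines the same formula as <code>adapted_sigmoid</code> under the hypothetical name <code>propensity</code> so that the snippet runs on its own. Because $\sigma(-z) = 1 - \sigma(z)$, the smallest and largest probabilities should also add up to one.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def propensity(z):
    # P(T=1 | Z=z) = 0.05 + 0.9 * sigmoid(z), same formula as adapted_sigmoid
    return 0.05 + 0.9 / (1 + np.exp(-z))


z_grid = np.linspace(-np.pi, np.pi, 1_001)
p = propensity(z_grid)
print("min P(T=1|Z):", p.min(), "| max P(T=1|Z):", p.max())
</code></pre></div></div>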

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>

<span class="k">def</span> <span class="nf">func_0</span><span class="p">(</span><span class="n">Z</span><span class="p">):</span>
    <span class="k">return</span> <span class="mi">10</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">func_1</span><span class="p">(</span><span class="n">Z</span><span class="p">):</span>
    <span class="k">return</span> <span class="mi">10</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">generate_data</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">obs</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="n">rs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">random_state</span><span class="p">).</span><span class="n">randint</span><span class="p">(</span>
        <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="o">**</span><span class="mi">32</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">int64</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">4</span>
    <span class="p">)</span>

    <span class="n">Z_obs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">rs</span><span class="p">[</span><span class="mi">0</span><span class="p">]).</span><span class="n">uniform</span><span class="p">(</span><span class="n">low</span><span class="o">=-</span><span class="n">np</span><span class="p">.</span><span class="n">pi</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">pi</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">size</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">g_T_noised</span><span class="p">(</span><span class="n">Z</span><span class="p">):</span>
        <span class="k">return</span> <span class="p">(</span>
            <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">rs</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
            <span class="p">.</span><span class="n">binomial</span><span class="p">(</span><span class="n">n</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">adapted_sigmoid</span><span class="p">(</span><span class="n">Z</span><span class="p">))</span>
            <span class="p">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">bool</span><span class="p">)</span>
        <span class="p">)</span>

    <span class="n">T_obs</span> <span class="o">=</span> <span class="n">g_T_noised</span><span class="p">(</span><span class="n">Z_obs</span><span class="p">)</span>

    <span class="n">noise</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">rs</span><span class="p">[</span><span class="mi">3</span><span class="p">]).</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">size</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">g_Y</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">noise</span><span class="p">):</span>
        <span class="k">return</span> <span class="p">(</span>
            <span class="n">np</span><span class="p">.</span><span class="n">select</span><span class="p">(</span><span class="n">condlist</span><span class="o">=</span><span class="p">[</span><span class="n">T</span><span class="p">],</span> <span class="n">choicelist</span><span class="o">=</span><span class="p">[</span><span class="n">func_1</span><span class="p">(</span><span class="n">Z</span><span class="p">)],</span> <span class="n">default</span><span class="o">=</span><span class="n">func_0</span><span class="p">(</span><span class="n">Z</span><span class="p">))</span>
            <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">T</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">Z</span><span class="p">))</span> <span class="o">*</span> <span class="n">noise</span>
        <span class="p">)</span>

    <span class="n">Y_obs</span> <span class="o">=</span> <span class="n">g_Y</span><span class="p">(</span><span class="n">T_obs</span><span class="p">,</span> <span class="n">Z_obs</span><span class="p">,</span> <span class="n">noise</span><span class="p">)</span>
    <span class="n">Y_cf</span> <span class="o">=</span> <span class="n">g_Y</span><span class="p">(</span><span class="o">~</span><span class="n">T_obs</span><span class="p">,</span> <span class="n">Z_obs</span><span class="p">,</span> <span class="n">noise</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">generate_df</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">Y</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span>
            <span class="n">np</span><span class="p">.</span><span class="n">vstack</span><span class="p">([</span><span class="n">T</span><span class="p">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">),</span> <span class="n">Z</span><span class="p">,</span> <span class="n">Y</span><span class="p">]).</span><span class="n">T</span><span class="p">,</span>
            <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">"treatment"</span><span class="p">,</span> <span class="s">"confounder"</span><span class="p">,</span> <span class="s">"target"</span><span class="p">],</span>
        <span class="p">)</span>

    <span class="n">df_obs</span> <span class="o">=</span> <span class="n">generate_df</span><span class="p">(</span><span class="n">T_obs</span><span class="p">,</span> <span class="n">Z_obs</span><span class="p">,</span> <span class="n">Y_obs</span><span class="p">)</span>
    <span class="n">df_cf</span> <span class="o">=</span> <span class="n">generate_df</span><span class="p">(</span><span class="o">~</span><span class="n">T_obs</span><span class="p">,</span> <span class="n">Z_obs</span><span class="p">,</span> <span class="n">Y_cf</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">df_obs</span><span class="p">,</span> <span class="n">df_cf</span>

<span class="n">df_obs</span><span class="p">,</span> <span class="n">df_cf</span> <span class="o">=</span> <span class="n">generate_data</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">50_000</span><span class="p">,</span> <span class="n">obs</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">Since we are dealing with synthetic data, we have access to both the observed and the counterfactual outcomes, so we can compute $Y_1 - Y_0$ for each example. This lets us evaluate our estimates on a test set separate from the training set, as is typical in supervised learning.</div></p>
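
<p><div align="justify">The sketch below shows that computation, assuming rows with the same <code>treatment</code>, <code>target</code>, and <code>target_cf</code> columns used in this post. The values are made up. $Y_1$ is the observed outcome for treated rows and the counterfactual outcome for untreated rows, and $Y_0$ is the other way around. The <code>ite</code> column name is a hypothetical choice.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
import pandas as pd

# Made-up rows: "target" is the observed outcome, "target_cf" the counterfactual one.
df = pd.DataFrame(
    {
        "treatment": [0, 1, 1, 0],
        "target": [1.0, 4.0, -2.0, 3.0],
        "target_cf": [2.5, 1.0, 0.0, 3.5],
    }
)

# Y_1 is observed for treated rows and counterfactual otherwise; Y_0 the reverse.
y1 = np.where(df.treatment == 1, df.target, df.target_cf)
y0 = np.where(df.treatment == 1, df.target_cf, df.target)
df["ite"] = y1 - y0  # individual treatment effect Y_1 - Y_0
print(df)
</code></pre></div></div>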

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>

<span class="n">df_train</span><span class="p">,</span> <span class="n">df_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span>
    <span class="n">df_obs</span><span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">target_cf</span><span class="o">=</span><span class="n">df_cf</span><span class="p">.</span><span class="n">target</span><span class="p">),</span>
    <span class="n">test_size</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span>
    <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">df_train_t0</span> <span class="o">=</span> <span class="n">df_train</span><span class="p">.</span><span class="n">query</span><span class="p">(</span><span class="s">"treatment == 0"</span><span class="p">)</span>
<span class="n">df_train_t1</span> <span class="o">=</span> <span class="n">df_train</span><span class="p">.</span><span class="n">query</span><span class="p">(</span><span class="s">"treatment == 1"</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">return_TZ_y</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="n">backdoor_set_list</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">df</span><span class="p">.</span><span class="nb">filter</span><span class="p">(</span><span class="n">backdoor_set_list</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">df</span><span class="p">.</span><span class="n">target</span><span class="p">)</span>

<span class="n">backdoor_set</span> <span class="o">=</span> <span class="p">[</span><span class="s">"confounder"</span><span class="p">]</span>

<span class="n">XZ_train_t0</span><span class="p">,</span> <span class="n">y_train_t0</span> <span class="o">=</span> <span class="n">return_TZ_y</span><span class="p">(</span><span class="n">df_train_t0</span><span class="p">,</span> <span class="n">backdoor_set</span><span class="p">)</span>
<span class="n">XZ_train_t1</span><span class="p">,</span> <span class="n">y_train_t1</span> <span class="o">=</span> <span class="n">return_TZ_y</span><span class="p">(</span><span class="n">df_train_t1</span><span class="p">,</span> <span class="n">backdoor_set</span><span class="p">)</span>

<span class="n">XZ_test</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">return_TZ_y</span><span class="p">(</span><span class="n">df_test</span><span class="p">,</span> <span class="n">backdoor_set</span><span class="p">)</span>
</code></pre></div></div>

<hr />

<h2 id="positivity-assumption">Positivity assumption</h2>

<p><div align="justify">One assumption that is often overlooked in Pearl's framework, but crucial to check for good estimation, is the positivity assumption: $0 &lt; \mathbb{P}(T=t | Z=z) &lt; 1$ for every treatment value $t$ and every $z$ in the support of $Z$. As we observed earlier, this assumption holds in our synthetic data, but in a real-life scenario it would require validation.</div></p>

<p><div align="justify">$\oint$ <em>If you apply an &quot;<a href="https://en.wikipedia.org/wiki/Multi-armed_bandit">$\varepsilon$-greedy strategy</a>&quot; to your population to introduce randomization, this assumption is guaranteed, because every individual has a nonzero probability of receiving each treatment. This highlights the importance of a continuous experimentation process in a product based on causal inference.</em></div></p>
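<p><div align="justify">A minimal sketch of why $\varepsilon$-greedy assignment guarantees positivity, assuming $K$ treatment arms and a hypothetical greedy policy <code>greedy_arm</code>: with probability $\varepsilon$ the arm is drawn uniformly at random, so $\mathbb{P}(T=t | Z=z) = \varepsilon/K + (1-\varepsilon)\,\mathbb{1}\{t = \text{greedy}(z)\} \geq \varepsilon/K$ for every arm $t$ and every $z$. The simulation below checks this bound in the worst case, where the greedy policy never picks one of the arms.</div></p>

```python
import numpy as np


def epsilon_greedy_assign(greedy_arm, n_arms, epsilon, rng):
    """Hypothetical epsilon-greedy treatment assignment.

    greedy_arm holds the arm the current policy would pick for each individual;
    with probability epsilon that choice is replaced by a uniformly random arm.
    """
    n = greedy_arm.shape[0]
    explore = rng.random(n) < epsilon
    random_arm = rng.integers(0, n_arms, size=n)
    return np.where(explore, random_arm, greedy_arm)


rng = np.random.default_rng(0)
n, n_arms, epsilon = 200_000, 2, 0.1

# Worst case for positivity: the greedy policy always picks arm 1.
greedy_arm = np.ones(n, dtype=int)
assigned = epsilon_greedy_assign(greedy_arm, n_arms, epsilon, rng)

# Empirical propensity of arm 0, which the greedy policy never picks.
share_arm_0 = np.mean(assigned == 0)
print(f"share of arm 0: {share_arm_0:.4f} (lower bound epsilon / K = {epsilon / n_arms:.4f})")
```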

<p><div align="justify">The importance of satisfying the positivity assumption is immediate: how do we predict what happens to $Y$ when $T$ takes a certain value in regions of $Z$ where no individual has received that treatment? The problem either becomes impossible, or the approximation becomes very poor because the model has to extrapolate from distant examples to predict at that point.</div></p>

<p><div align="justify">The common approach to check this is to fit a model that predicts $T$ from $Z$ and then evaluate it. If this model performs exceptionally well, the relationship between $Z$ and $T$ is likely close to deterministic, at least in some regions, which violates the positivity assumption. In the case of binary treatment, which is our scenario, we can use a reasonably well-calibrated model (or calibrate the model ourselves <a href="#bibliography">[6]</a>) and examine the distribution of its predicted probabilities, i.e., the estimated propensity scores $\mathbb{P}(T=1 | Z=z)$.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">roc_auc_score</span>
<span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">LogisticRegression</span>

<span class="n">positivity_assumption_check_estimator</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span>
    <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">,</span>
<span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">df_train</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">"treatment"</span><span class="p">,</span> <span class="s">"target"</span><span class="p">,</span> <span class="s">"target_cf"</span><span class="p">]),</span> <span class="n">df_train</span><span class="p">.</span><span class="n">treatment</span><span class="p">)</span>

<span class="n">roc_auc_score</span><span class="p">(</span>
    <span class="n">df_test</span><span class="p">.</span><span class="n">treatment</span><span class="p">,</span> <span class="n">positivity_assumption_check_estimator</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">XZ_test</span><span class="p">)[:,</span> <span class="mi">1</span><span class="p">]</span>
<span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.8370462957096292
</code></pre></div></div>

<p><div align="justify">The <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html"><code>sklearn.metrics.roc_auc_score</code></a> of about 0.84 already suggests a plausible scenario for the positivity assumption: $Z$ is informative about $T$ but does not determine it. When there are regions where $T$ is a deterministic function of $Z$, the ROC AUC is typically close to 1. Still, the AUC is only a summary, and a small deterministic region may barely move it, which is why we also inspect the predicted probabilities.</div></p>

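<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.calibration</span> <span class="kn">import</span> <span class="n">calibration_curve</span>

<span class="n">probs</span> <span class="o">=</span> <span class="n">positivity_assumption_check_estimator</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">XZ_test</span><span class="p">)[:,</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">prob_true</span><span class="p">,</span> <span class="n">prob_pred</span> <span class="o">=</span> <span class="n">calibration_curve</span><span class="p">(</span><span class="n">df_test</span><span class="p">.</span><span class="n">treatment</span><span class="p">,</span> <span class="n">probs</span><span class="p">,</span> <span class="n">n_bins</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="c1"># Number of samples in each non-empty bin, using uniform bins as calibration_curve does by default.</span>
<span class="n">bin_ids</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">11</span><span class="p">)[</span><span class="mi">1</span><span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">probs</span><span class="p">)</span>
<span class="n">size_bin</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">bincount</span><span class="p">(</span><span class="n">bin_ids</span><span class="p">,</span> <span class="n">minlength</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="n">size_bin</span> <span class="o">=</span> <span class="n">size_bin</span><span class="p">[</span><span class="n">size_bin</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">]</span>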

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">ncols</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">plot</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="s">"--"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">scatter</span><span class="p">(</span><span class="n">prob_true</span><span class="p">,</span> <span class="n">prob_pred</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="p">(</span><span class="mf">0.1</span> <span class="o">*</span> <span class="n">size_bin</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">),</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s">"k"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"True probability of bin"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"Mean predicted probability of bin"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">hist</span><span class="p">(</span>
    <span class="n">probs</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">21</span><span class="p">),</span> <span class="n">weights</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">probs</span><span class="p">)</span> <span class="o">/</span> <span class="n">probs</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"Histogram of predicted probability"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/cqr_cate/output_15_0.png" /></center></div></p>

<p><div align="justify">Indeed, after confirming that the model is reasonably calibrated, we can observe that the histogram of predicted probabilities contains no examples close to 0 or 1. This suggests that we are in an appropriate scenario for estimating CATE.</div></p>
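<p><div align="justify">To build intuition for both diagnostics, here is a small sketch on hypothetical one-dimensional data (not the post's dataset) comparing a scenario with overlap against one where $T$ is a deterministic function of $Z$. It fits a <code>LogisticRegression</code> propensity model in each case and reports the ROC AUC and the share of propensity scores outside $[0.01, 0.99]$; these thresholds are illustrative choices, not a standard.</div></p>

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
z = rng.normal(size=(5_000, 1))  # hypothetical one-dimensional confounder

scenarios = {
    # Overlap: larger z makes treatment more likely, but never certain.
    "overlap": rng.binomial(1, 1 / (1 + np.exp(-z[:, 0]))),
    # Positivity violation: treatment is fully determined by the sign of z.
    "deterministic": (z[:, 0] > 0).astype(int),
}

results = {}
for name, t in scenarios.items():
    propensity = LogisticRegression(max_iter=1_000).fit(z, t).predict_proba(z)[:, 1]
    auc = roc_auc_score(t, propensity)
    share_extreme = np.mean((propensity < 0.01) | (propensity > 0.99))
    results[name] = (auc, share_extreme)
    print(f"{name}: AUC={auc:.3f}, share outside [0.01, 0.99]={share_extreme:.2f}")
```

Because the regularized logistic regression keeps its coefficient finite, even the deterministic scenario produces moderate scores near $z = 0$, so looking at the whole distribution of scores is more informative than any single number.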

<p><div align="justify">$\oint$ <em>The continuous-treatment scenario is slightly more complex, but evaluating regression metrics of a model that predicts $T$ from $Z$ can provide good intuition about this relationship: a near-perfect fit suggests that $T$ is almost determined by $Z$. Another viable technique is to discretize the treatment and analyze the predicted probabilities of each bin in the same way as in the binary case.</em></div></p>
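<p><div align="justify">A minimal sketch of the discretization idea, assuming a hypothetical continuous treatment that depends linearly on $Z$: we cut $T$ into quartile bins, fit a multinomial <code>LogisticRegression</code> to estimate $\mathbb{P}(\text{bin} | Z=z)$, and look at the smallest bin probability of each individual, which plays the role of the propensity score in the binary case.</div></p>

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n = 5_000
z = rng.normal(size=(n, 1))
# Hypothetical continuous treatment, correlated with z but with noise.
t = 0.8 * z[:, 0] + rng.normal(scale=0.6, size=n)

# Discretize T into quartile bins: labels 0, 1, 2, 3.
edges = np.quantile(t, [0.25, 0.5, 0.75])
t_bin = np.digitize(t, edges)

# With the default lbfgs solver and more than two classes,
# scikit-learn fits a multinomial (softmax) logistic regression.
bin_probs = LogisticRegression(max_iter=1_000).fit(z, t_bin).predict_proba(z)

# Smallest bin probability per individual: values near 0 flag regions of Z
# where some treatment levels are (almost) never observed.
min_bin_prob = bin_probs.min(axis=1)
print(np.quantile(min_bin_prob, [0.01, 0.5]))
```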

<hr />

<h2 id="conformalized-quantile-regression">Conformalized Quantile Regression</h2>

<p><div align="justify">Quantile regression with the pinball loss <a href="#bibliography">[7]</a> is a suitable method for predicting conditional quantiles of a target variable. However, the estimates $Q_{\beta}(z)$ and $Q_{1-\beta}(z)$ of the conditional quantiles of levels $\beta$ and $1 - \beta$ of $Y | Z=z$, with $\beta \in (0, 1/2)$, are not guaranteed to satisfy, in finite samples, the coverage property $\mathbb{P}\left(Y \in [Q_{\beta}(Z), Q_{1-\beta}(Z)]\right) \geq 1 - 2 \beta$ <a href="#bibliography">[8]</a>.</div></p>
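<p><div align="justify">For reference, the pinball loss at level $\tau$ is $\rho_\tau(y, \hat{y}) = (y - \hat{y})\,(\tau - \mathbb{1}\{y &lt; \hat{y}\})$, the same expression implemented in <code>_quantile_loss</code> in the class below. The sketch checks numerically, on hypothetical Gaussian data, that the constant prediction minimizing the average pinball loss is (up to the grid resolution) the empirical $\tau$-quantile, which is why minimizing this loss yields quantile estimates.</div></p>

```python
import numpy as np


def pinball_loss(y_true, y_pred, tau):
    """Mean pinball (quantile) loss at level tau."""
    residual = y_true - y_pred
    # Under-predictions (residual > 0) cost tau per unit; over-predictions cost (1 - tau).
    return np.mean(residual * (tau - (residual < 0)))


rng = np.random.default_rng(0)
y = rng.normal(size=5_000)  # hypothetical data: standard normal samples
tau = 0.9

# Grid search over constant predictions: the minimizer should sit at the empirical tau-quantile.
candidates = np.linspace(-3, 3, 2001)
losses = [pinball_loss(y, c, tau) for c in candidates]
best_constant = candidates[int(np.argmin(losses))]

print(best_constant, np.quantile(y, tau))
```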

<p><div align="justify">Conformalized Quantile Regression (CQR) keeps the quantile regression approach above but corrects its predicted conditional quantiles using conformity scores computed on a held-out calibration set, thereby ensuring marginal coverage in finite samples, provided the data are exchangeable <a href="#bibliography">[1, 8]</a>.</div></p>
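<p><div align="justify">The calibration step can be sketched in a few lines of NumPy. Following the CQR recipe, the conformity score of a calibration point is $E_i = \max\{\hat{q}_{lo}(z_i) - y_i,\ y_i - \hat{q}_{hi}(z_i)\}$, and the final interval $[\hat{q}_{lo}(z) - \hat{q},\ \hat{q}_{hi}(z) + \hat{q}]$ uses as $\hat{q}$ the $\lceil (n+1)(1-\alpha) \rceil$-th smallest of the $n$ calibration scores. The data and the deliberately too-narrow &quot;quantile predictions&quot; $z \pm 1$ below are hypothetical, chosen only to show that the correction restores the target coverage; this is an unweighted illustration, not the implementation used in this post.</div></p>

```python
import numpy as np


def cqr_correction(lower_cal, upper_cal, y_cal, alpha):
    """Split-conformal correction used by CQR (unweighted sketch)."""
    # Conformity score: distance by which y falls outside [lower, upper] (negative if inside).
    scores = np.maximum(lower_cal - y_cal, y_cal - upper_cal)
    n = scores.shape[0]
    k = int(np.ceil((n + 1) * (1 - alpha)))
    if k > n:  # too few calibration points for this alpha: the interval must be infinite
        return np.inf
    return np.sort(scores)[k - 1]


rng = np.random.default_rng(0)
alpha, n_cal, n_test = 0.1, 1_000, 10_000

# Hypothetical data: y = z + N(0, 1) noise; the "quantile model" z +/- 1 under-covers (about 68%).
z_cal, z_test = rng.uniform(0, 10, n_cal), rng.uniform(0, 10, n_test)
y_cal = z_cal + rng.normal(size=n_cal)
y_test = z_test + rng.normal(size=n_test)

q_hat = cqr_correction(z_cal - 1, z_cal + 1, y_cal, alpha)
raw_coverage = np.mean(np.abs(y_test - z_test) <= 1)
cqr_coverage = np.mean((z_test - 1 - q_hat <= y_test) & (y_test <= z_test + 1 + q_hat))
print(f"q_hat={q_hat:.3f}, raw coverage={raw_coverage:.3f}, CQR coverage={cqr_coverage:.3f}")
```

The class in this post computes a weighted version of this quantile with <code>statsmodels</code>' <code>DescrStatsW</code>, so that sample weights can be taken into account.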

<p><div align="justify">We can implement a version of Conformalized Quantile Regression following this strategy, trying to adhere to the <a href="https://scikit-learn.org/stable/developers/develop.html">scikit-learn standards</a> and using <a href="https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMRegressor.html"><code>lightgbm.LGBMRegressor</code></a> with <code>objective="quantile"</code> as the quantile regressor, whose hyperparameters are tuned with <code>RandomizedSearchCV</code> under the pinball loss.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<span class="kn">from</span> <span class="nn">lightgbm</span> <span class="kn">import</span> <span class="n">LGBMRegressor</span>
<span class="kn">from</span> <span class="nn">scipy.stats</span> <span class="kn">import</span> <span class="n">loguniform</span>
<span class="kn">from</span> <span class="nn">sklearn.base</span> <span class="kn">import</span> <span class="n">BaseEstimator</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">make_scorer</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">RandomizedSearchCV</span><span class="p">,</span> <span class="n">KFold</span><span class="p">,</span> <span class="n">train_test_split</span>
<span class="kn">from</span> <span class="nn">sklearn.utils.validation</span> <span class="kn">import</span> <span class="n">check_X_y</span><span class="p">,</span> <span class="n">check_is_fitted</span><span class="p">,</span> <span class="n">_check_sample_weight</span>
<span class="kn">from</span> <span class="nn">statsmodels.stats.weightstats</span> <span class="kn">import</span> <span class="n">DescrStatsW</span>

<span class="k">class</span> <span class="nc">ConformalizedQuantileRegression</span><span class="p">(</span><span class="n">BaseEstimator</span><span class="p">):</span>
    <span class="s">"""
    Conformalized Quantile Regression with LGBMRegressor.

    This estimator provides prediction intervals for one-dimensional
    regression tasks by using CQR with LightGBM quantile regressors.

    Parameters
    ----------
    alpha : float, default=0.05
        Determines the size of the prediction interval. For example,
        alpha=0.05 results in a 95% coverage prediction interval.

    calibration_size : float, default=0.2
        The proportion of the dataset to be used for the calibration set
        which computes the conformity scores.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness for reproducibility.

    n_iter_cv : int, default=10
        Number of parameter settings that are sampled in RandomizedSearchCV
        for the LightGBM model during fit.
    """</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">calibration_size</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">n_iter_cv</span><span class="o">=</span><span class="mi">10</span>
    <span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">calibration_size</span> <span class="o">=</span> <span class="n">calibration_size</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">random_state</span> <span class="o">=</span> <span class="n">random_state</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_iter_cv</span> <span class="o">=</span> <span class="n">n_iter_cv</span>

    <span class="k">def</span> <span class="nf">_quantile_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">quantile</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">sample_weights</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">weighted_errors</span> <span class="o">=</span> <span class="p">(</span><span class="n">y_true</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">quantile</span> <span class="o">-</span> <span class="p">(</span><span class="n">y_true</span> <span class="o">&lt;</span> <span class="n">y_pred</span><span class="p">))</span>
        <span class="k">if</span> <span class="n">sample_weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">weighted_errors</span> <span class="o">*=</span> <span class="n">sample_weights</span>
        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">weighted_errors</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_return_quantile_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">quantile</span><span class="p">):</span>
        <span class="n">quantile_scorer</span> <span class="o">=</span> <span class="n">make_scorer</span><span class="p">(</span>
            <span class="n">partial</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">_quantile_loss</span><span class="p">,</span> <span class="n">quantile</span><span class="o">=</span><span class="n">quantile</span><span class="p">),</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="bp">False</span>
        <span class="p">)</span>

        <span class="k">return</span> <span class="n">RandomizedSearchCV</span><span class="p">(</span>
            <span class="n">estimator</span><span class="o">=</span><span class="n">LGBMRegressor</span><span class="p">(</span>
                <span class="n">random_state</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span><span class="p">,</span> <span class="n">objective</span><span class="o">=</span><span class="s">"quantile"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">quantile</span>
            <span class="p">),</span>
            <span class="n">cv</span><span class="o">=</span><span class="n">KFold</span><span class="p">(</span><span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span><span class="p">),</span>
            <span class="n">param_distributions</span><span class="o">=</span><span class="p">{</span>
                <span class="s">"learning_rate"</span><span class="p">:</span> <span class="n">loguniform</span><span class="p">.</span><span class="n">rvs</span><span class="p">(</span>
                    <span class="n">random_state</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="mf">0.0001</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">1000</span>
                <span class="p">),</span>
                <span class="s">"n_estimators"</span><span class="p">:</span> <span class="p">[</span><span class="mi">50</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">200</span><span class="p">],</span>
                <span class="s">"num_leaves"</span><span class="p">:</span> <span class="p">[</span><span class="mi">31</span><span class="p">,</span> <span class="mi">63</span><span class="p">,</span> <span class="mi">127</span><span class="p">],</span>
            <span class="p">},</span>
            <span class="n">scoring</span><span class="o">=</span><span class="n">quantile_scorer</span><span class="p">,</span>
            <span class="n">n_iter</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">n_iter_cv</span><span class="p">,</span>
            <span class="n">random_state</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span><span class="p">,</span>
            <span class="n">n_jobs</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">check_X_y</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
        <span class="n">sample_weight</span> <span class="o">=</span> <span class="n">_check_sample_weight</span><span class="p">(</span><span class="n">sample_weight</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span>

        <span class="p">(</span>
            <span class="n">X_train</span><span class="p">,</span>
            <span class="n">X_cal</span><span class="p">,</span>
            <span class="n">y_train</span><span class="p">,</span>
            <span class="n">y_cal</span><span class="p">,</span>
            <span class="n">sample_weight_train</span><span class="p">,</span>
            <span class="n">sample_weight_cal</span><span class="p">,</span>
        <span class="p">)</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span>
            <span class="n">X</span><span class="p">,</span>
            <span class="n">y</span><span class="p">,</span>
            <span class="n">sample_weight</span><span class="p">,</span>
            <span class="n">test_size</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">calibration_size</span><span class="p">,</span>
            <span class="n">random_state</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span><span class="p">,</span>
        <span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">model_lower_</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_return_quantile_model</span><span class="p">(</span><span class="n">quantile</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">alpha</span> <span class="o">/</span> <span class="mi">2</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span>
            <span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="n">sample_weight_train</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">model_upper_</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_return_quantile_model</span><span class="p">(</span>
            <span class="n">quantile</span><span class="o">=</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">alpha</span> <span class="o">/</span> <span class="mi">2</span>
        <span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="n">sample_weight_train</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">y_cal_conformity_scores_</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">model_lower_</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_cal</span><span class="p">)</span> <span class="o">-</span> <span class="n">y_cal</span><span class="p">,</span>
            <span class="n">y_cal</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">model_upper_</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_cal</span><span class="p">),</span>
        <span class="p">)</span>
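        <span class="n">wq</span> <span class="o">=</span> <span class="n">DescrStatsW</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">y_cal_conformity_scores_</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">sample_weight_cal</span><span class="p">)</span>
        <span class="c1"># Finite-sample correction (1 - alpha)(1 + 1/n_cal) from the unweighted CQR</span>
        <span class="c1"># procedure; with non-uniform sample weights it is only an approximation.</span>
        <span class="n">conformal_level</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">alpha</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">y_cal</span><span class="p">)))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">quantile_conformity_scores_</span> <span class="o">=</span> <span class="n">wq</span><span class="p">.</span><span class="n">quantile</span><span class="p">(</span>
            <span class="n">probs</span><span class="o">=</span><span class="n">conformal_level</span><span class="p">,</span> <span class="n">return_pandas</span><span class="o">=</span><span class="bp">False</span>
        <span class="p">)[</span><span class="mi">0</span><span class="p">]</span>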

        <span class="k">return</span> <span class="bp">self</span>

    <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="s">"""
        Predicts conformalized quantile regression intervals for X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        y_test_interval_pred_cqr : ndarray of shape (n_samples, 2)
            Returns the predicted lower and upper bound for each sample in X.
        """</span>
        <span class="n">check_is_fitted</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
        <span class="n">y_test_interval_pred_cqr</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">column_stack</span><span class="p">(</span>
            <span class="p">[</span>
                <span class="bp">self</span><span class="p">.</span><span class="n">model_lower_</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X</span><span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">quantile_conformity_scores_</span><span class="p">,</span>
                <span class="bp">self</span><span class="p">.</span><span class="n">model_upper_</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">quantile_conformity_scores_</span><span class="p">,</span>
            <span class="p">]</span>
        <span class="p">)</span>
        <span class="k">return</span> <span class="n">y_test_interval_pred_cqr</span>
</code></pre></div></div>

<hr />

<h2 id="using-the-t-learner">Using the T-learner</h2>

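<p><div align="justify">In this example, we will use the T-learner technique <a href="#bibliography">[3, 9]</a>, building one model per treatment arm, each trained only on samples with $T=t$, to model $Y$ under $do(T=t)$ given $Z$ for $t\in\{0, 1\}$. Since $Z$ satisfies the backdoor criterion, the distribution of $Y|do(T=t), Z$ equals that of $Y|T=t, Z$, which is what each model learns. We will set <code>alpha=0.05</code> to construct prediction intervals with 95% marginal coverage.</div></p>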

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model_t0</span> <span class="o">=</span> <span class="n">ConformalizedQuantileRegression</span><span class="p">(</span>
    <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">n_iter_cv</span><span class="o">=</span><span class="mi">30</span>
<span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">XZ_train_t0</span><span class="p">,</span> <span class="n">y_train_t0</span><span class="p">)</span>
<span class="n">y_test_interval_pred_cqr_t0</span> <span class="o">=</span> <span class="n">model_t0</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">XZ_test</span><span class="p">)</span>

<span class="n">model_t1</span> <span class="o">=</span> <span class="n">ConformalizedQuantileRegression</span><span class="p">(</span>
    <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">n_iter_cv</span><span class="o">=</span><span class="mi">30</span>
<span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">XZ_train_t1</span><span class="p">,</span> <span class="n">y_train_t1</span><span class="p">)</span>
<span class="n">y_test_interval_pred_cqr_t1</span> <span class="o">=</span> <span class="n">model_t1</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">XZ_test</span><span class="p">)</span>
</code></pre></div></div>
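<p><div align="justify">To make the T-learner logic concrete in its simplest point-estimate form, the sketch below fits one regressor per arm on hypothetical confounded data with a known effect $\tau(z) = 1 + 2z$ and estimates the CATE as $\hat{\mu}_1(z) - \hat{\mu}_0(z)$. Plain <code>LinearRegression</code> models stand in for the CQR models here, so this illustrates only the structure of the T-learner, not how prediction intervals are combined.</div></p>

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(42)
n = 4_000
z = rng.normal(size=(n, 1))
# Hypothetical confounded data: larger z raises P(T=1), and the true effect is tau(z) = 1 + 2z.
t = rng.binomial(1, 1 / (1 + np.exp(-z[:, 0])))
y = z[:, 0] + t * (1 + 2 * z[:, 0]) + rng.normal(scale=0.5, size=n)

# T-learner: one regressor per arm, each fitted only on that arm's samples.
mu0 = LinearRegression().fit(z[t == 0], y[t == 0])
mu1 = LinearRegression().fit(z[t == 1], y[t == 1])

z_new = np.array([[-1.0], [0.0], [1.0]])
cate_hat = mu1.predict(z_new) - mu0.predict(z_new)
print(cate_hat)  # true CATE at these points: -1, 1, 3
```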

<p><div align="justify">$\oint$ <em>You may want to implement an importance weighting strategy here to obtain better prediction sets in regions where $P(T=t | Z=z)$ is close to zero (naturally, the regions with fewer examples). Each model is trained on the distribution of $Z$ among individuals with $T=t$ but applied to the whole population, so we can interpret this as a <a href="https://vitaliset.github.io/covariate-shift-0-introduction/">covariate shift</a> setting, in which the covariates of the population where we apply the model differ from those of the population where we train it. However, if the positivity assumption holds, this may be less critical (especially with flexible models that are less prone to underfitting, such as tree ensembles <a href="#bibliography">[10]</a>).</em></div></p>
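<p><div align="justify">One common way to obtain such weights, and the idea behind the <code>return_sample_weight_treatment_i</code> function in this post, is the classifier-based density-ratio trick: train a classifier to distinguish training samples (label 0) from the samples where the model will be applied (label 1); then $\frac{p(z | 1)}{p(z | 0)} = \frac{\mathbb{P}(1 | z)}{\mathbb{P}(0 | z)} \cdot \frac{n_0}{n_1}$. The sketch checks this on hypothetical Gaussian data where the true ratio is known: $\mathcal{N}(1, 1)$ over $\mathcal{N}(0, 1)$ gives $e^{z - 1/2}$.</div></p>

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n_src, n_tgt = 20_000, 20_000
# Hypothetical covariate shift: training ("source") Z ~ N(0, 1), application ("target") Z ~ N(1, 1).
z_src = rng.normal(0.0, 1.0, size=(n_src, 1))
z_tgt = rng.normal(1.0, 1.0, size=(n_tgt, 1))

# Domain classifier: label 0 = source, label 1 = target.
domain_clf = LogisticRegression().fit(
    np.vstack([z_src, z_tgt]), np.r_[np.zeros(n_src), np.ones(n_tgt)]
)

p = domain_clf.predict_proba(z_src)
# P(1|z) / P(0|z) times n_src / n_tgt estimates p_tgt(z) / p_src(z).
weights = p[:, 1] / p[:, 0] * (n_src / n_tgt)
true_ratio = np.exp(z_src[:, 0] - 0.5)

# The importance-weighted source mean should approximate the target mean (1.0).
weighted_mean = np.average(z_src[:, 0], weights=weights)
print(np.median(np.abs(weights / true_ratio - 1)), weighted_mean)
```

The constant factor $n_0/n_1$ does not affect normalized uses of the weights, such as a weighted quantile.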

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">return_sample_weight_treatment_i</span><span class="p">(</span><span class="n">df_train</span><span class="p">,</span> <span class="n">df_test</span><span class="p">):</span>
    <span class="n">df_ood_ti</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span>
        <span class="p">[</span>
            <span class="n">df</span><span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">train_or_test</span><span class="o">=</span><span class="n">j</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">df</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span>
                <span class="p">[</span>
                    <span class="n">df_train</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">"treatment"</span><span class="p">,</span> <span class="s">"target"</span><span class="p">,</span> <span class="s">"target_cf"</span><span class="p">]),</span>
                    <span class="n">df_test</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">"treatment"</span><span class="p">,</span> <span class="s">"target"</span><span class="p">,</span> <span class="s">"target_cf"</span><span class="p">]),</span>
                <span class="p">]</span>
            <span class="p">)</span>
        <span class="p">]</span>
    <span class="p">)</span>

    <span class="c1"># Domain classifier on the covariates Z only (0 = training arm, 1 = test population);</span>
    <span class="c1"># the target is excluded so that the importance weights depend on Z alone.</span>
    <span class="n">ood_sample_correction_ti</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span>
        <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">,</span>
    <span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">df_ood_ti</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">"train_or_test"</span><span class="p">]),</span> <span class="n">df_ood_ti</span><span class="p">.</span><span class="n">train_or_test</span><span class="p">)</span>

    <span class="n">roc</span> <span class="o">=</span> <span class="n">roc_auc_score</span><span class="p">(</span>
        <span class="n">df_ood_ti</span><span class="p">.</span><span class="n">train_or_test</span><span class="p">,</span>
        <span class="n">ood_sample_correction_ti</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span>
            <span class="n">df_ood_ti</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">"train_or_test"</span><span class="p">])</span>
        <span class="p">)[:,</span> <span class="mi">1</span><span class="p">],</span>
    <span class="p">)</span>

    <span class="n">probs</span> <span class="o">=</span> <span class="n">ood_sample_correction_ti</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span>
        <span class="n">df_train</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">"treatment"</span><span class="p">,</span> <span class="s">"target"</span><span class="p">,</span> <span class="s">"target_cf"</span><span class="p">])</span>
    <span class="p">)</span>
    <span class="c1"># Equivalent to `probs[:, 1]/probs[:, 0]`.
</span>    <span class="n">sample_weights_ti</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">probs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>

    <span class="k">return</span> <span class="n">roc</span><span class="p">,</span> <span class="n">sample_weights_ti</span>

<span class="n">_</span><span class="p">,</span> <span class="n">sw_0</span> <span class="o">=</span> <span class="n">return_sample_weight_treatment_i</span><span class="p">(</span><span class="n">df_train</span><span class="o">=</span><span class="n">df_train_t0</span><span class="p">,</span> <span class="n">df_test</span><span class="o">=</span><span class="n">df_test</span><span class="p">)</span>
</code></pre></div></div>

<hr />

<h2 id="evaluating-the-conformal-regression">Evaluating the conformal regression</h2>

<p><div align="justify">With the interval estimates calculated in <code>y_test_interval_pred_cqr_t0</code> and <code>y_test_interval_pred_cqr_t1</code>, we can assess the quality of our predictions. To do this, we examine the coverage of the intervals in both the observed and the counterfactual scenarios (since we also have the counterfactual outcome available for evaluation), as well as the size of the intervals.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">df_val</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">df_test</span><span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">pred_lower_t_0</span><span class="o">=</span><span class="n">y_test_interval_pred_cqr_t0</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">pred_upper_t_0</span><span class="o">=</span><span class="n">y_test_interval_pred_cqr_t0</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">ic_size_t_0</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_upper_t_0</span> <span class="o">-</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_lower_t_0</span><span class="p">)</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">pred_lower_t_1</span><span class="o">=</span><span class="n">y_test_interval_pred_cqr_t1</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">pred_upper_t_1</span><span class="o">=</span><span class="n">y_test_interval_pred_cqr_t1</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">ic_size_t_1</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_upper_t_1</span> <span class="o">-</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_lower_t_1</span><span class="p">)</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span>
        <span class="n">prob</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">positivity_assumption_check_estimator</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span>
            <span class="n">df_</span><span class="p">.</span><span class="nb">filter</span><span class="p">(</span><span class="n">backdoor_set</span><span class="p">)</span>
        <span class="p">)[:,</span> <span class="mi">1</span><span class="p">]</span>
    <span class="p">)</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">prob_cut</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">pd</span><span class="p">.</span><span class="n">cut</span><span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">prob</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">)))</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span>
        <span class="n">coverage</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">select</span><span class="p">(</span>
            <span class="n">condlist</span><span class="o">=</span><span class="p">[</span><span class="n">df_</span><span class="p">.</span><span class="n">treatment</span> <span class="o">==</span> <span class="mi">0</span><span class="p">],</span>
            <span class="n">choicelist</span><span class="o">=</span><span class="p">[</span>
                <span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">target</span> <span class="o">&gt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_lower_t_0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">target</span> <span class="o">&lt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_upper_t_0</span><span class="p">)</span>
            <span class="p">],</span>
            <span class="n">default</span><span class="o">=</span><span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">target</span> <span class="o">&gt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_lower_t_1</span><span class="p">)</span>
            <span class="o">&amp;</span> <span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">target</span> <span class="o">&lt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_upper_t_1</span><span class="p">),</span>
        <span class="p">)</span>
    <span class="p">)</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span>
        <span class="n">coverage_cf</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">select</span><span class="p">(</span>
            <span class="n">condlist</span><span class="o">=</span><span class="p">[</span><span class="n">df_</span><span class="p">.</span><span class="n">treatment</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">],</span>
            <span class="n">choicelist</span><span class="o">=</span><span class="p">[</span>
                <span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">target_cf</span> <span class="o">&gt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_lower_t_0</span><span class="p">)</span>
                <span class="o">&amp;</span> <span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">target_cf</span> <span class="o">&lt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_upper_t_0</span><span class="p">)</span>
            <span class="p">],</span>
            <span class="n">default</span><span class="o">=</span><span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">target_cf</span> <span class="o">&gt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_lower_t_1</span><span class="p">)</span>
            <span class="o">&amp;</span> <span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">target_cf</span> <span class="o">&lt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_upper_t_1</span><span class="p">),</span>
        <span class="p">)</span>
    <span class="p">)</span>
<span class="p">)</span>

<span class="n">df_val</span><span class="p">.</span><span class="n">coverage</span><span class="p">.</span><span class="n">mean</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.9497
</code></pre></div></div>

<p><div align="justify">It's important to highlight that conformal prediction guarantees marginal coverage, which does not necessarily translate into conditional coverage <a href="#bibliography">[1]</a>. We could be over-covering in some regions of $Z$ and under-covering in others and still achieve good marginal coverage, because the errors would average out. To examine this, we would need to study</div></p>

\[P((Y|Z=z)\in \tau(Z=z) \,|\, T=t, Z=z),\]

<p><div align="justify">where $\tau(Z=z)$ is the prediction set for $Z=z$.</div></p>
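<p><div align="justify">To make the gap between marginal and conditional coverage concrete, here is a small synthetic sketch that is not part of the original experiment (the data-generating process and the function <code>coverage_by_region</code> are hypothetical). The noise grows with $z$, but a single constant-width interval is calibrated with split conformal on absolute residuals around the true mean, which we assume known to keep the example short. Marginal coverage should land near 95%, while coverage is much higher for small $z$ and lower for large $z$.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def coverage_by_region(n=20_000, alpha=0.05, seed=0):
    """Hypothetical example: marginal vs. conditional coverage of a constant-width interval."""
    rng = np.random.default_rng(seed)

    def sample(size):
        # Heteroscedastic data: the noise standard deviation grows with z.
        z = rng.uniform(0, 1, size)
        y = 2 * z + rng.normal(0, 0.1 + z)
        return z, y

    z_cal, y_cal = sample(n)
    z_test, y_test = sample(n)

    # Split conformal on absolute residuals around the (assumed known) mean 2 * z.
    # This gives the same interval width for every z.
    scores = np.abs(y_cal - 2 * z_cal)
    q_level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
    q_hat = np.quantile(scores, q_level, method="higher")

    covered = q_hat >= np.abs(y_test - 2 * z_test)
    high_noise = z_test >= 0.5
    return {
        "marginal": float(covered.mean()),
        "z in [0, 0.5)": float(covered[~high_noise].mean()),
        "z in [0.5, 1)": float(covered[high_noise].mean()),
    }


print(coverage_by_region())
</code></pre></div></div>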

<p><div align="justify">One way to visualize this is to partition the space of $Z$ using the propensity score $P(T=1 | Z=z)$ (estimated by the same model used in the positivity assumption check) and to compute, within each bucket, the empirical coverage, i.e., the mean of the indicator of $(Y|Z=z)\in \tau(Z=z)$. If we further split each bucket by treatment, we obtain an approximation of the conditional coverage.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">scipy.stats</span> <span class="kn">import</span> <span class="n">bootstrap</span>

<span class="k">def</span> <span class="nf">bootstrap_ci</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">ci</span><span class="o">=</span><span class="mf">0.95</span><span class="p">):</span>
    <span class="n">boot</span> <span class="o">=</span> <span class="n">bootstrap</span><span class="p">((</span><span class="n">x</span><span class="p">,),</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span> <span class="n">confidence_level</span><span class="o">=</span><span class="n">ci</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="nb">round</span><span class="p">(</span><span class="n">boot</span><span class="p">.</span><span class="n">confidence_interval</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>

<span class="n">df_val_cond_aux1</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">df_val</span><span class="p">.</span><span class="n">groupby</span><span class="p">([</span><span class="s">"prob_cut"</span><span class="p">,</span> <span class="s">"treatment"</span><span class="p">])</span>
    <span class="p">.</span><span class="n">coverage</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">bootstrap_ci</span><span class="p">)</span>
    <span class="p">.</span><span class="n">to_frame</span><span class="p">()</span>
    <span class="p">.</span><span class="n">rename</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">{</span><span class="s">"coverage"</span><span class="p">:</span> <span class="s">"coverage_confidence_interval"</span><span class="p">})</span>
<span class="p">)</span>

<span class="n">df_val_cond_aux2</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">df_val</span><span class="p">.</span><span class="n">groupby</span><span class="p">([</span><span class="s">"prob_cut"</span><span class="p">,</span> <span class="s">"treatment"</span><span class="p">])</span>
    <span class="p">.</span><span class="n">coverage_cf</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">bootstrap_ci</span><span class="p">)</span>
    <span class="p">.</span><span class="n">to_frame</span><span class="p">()</span>
    <span class="p">.</span><span class="n">rename</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">{</span><span class="s">"coverage_cf"</span><span class="p">:</span> <span class="s">"coverage_cf_confidence_interval"</span><span class="p">})</span>
<span class="p">)</span>

<span class="n">df_val_cond_aux3</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">df_val</span><span class="p">.</span><span class="n">groupby</span><span class="p">([</span><span class="s">"prob_cut"</span><span class="p">,</span> <span class="s">"treatment"</span><span class="p">])</span>
    <span class="p">.</span><span class="n">agg</span><span class="p">(</span>
        <span class="p">{</span>
            <span class="s">"coverage"</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span>
            <span class="s">"coverage_cf"</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span>
            <span class="s">"ic_size_t_0"</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span>
            <span class="s">"ic_size_t_1"</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span>
        <span class="p">}</span>
    <span class="p">)</span>
    <span class="p">.</span><span class="n">rename</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="k">lambda</span> <span class="n">col</span><span class="p">:</span> <span class="n">col</span> <span class="o">+</span> <span class="s">"_mean"</span><span class="p">)</span>
<span class="p">)</span>

<span class="n">pd</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span>
    <span class="p">[</span><span class="n">df_val_cond_aux1</span><span class="p">,</span> <span class="n">df_val_cond_aux2</span><span class="p">,</span> <span class="n">df_val_cond_aux3</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span>
<span class="p">).</span><span class="n">reset_index</span><span class="p">().</span><span class="n">sort_values</span><span class="p">([</span><span class="s">"treatment"</span><span class="p">,</span> <span class="s">"prob_cut"</span><span class="p">])</span>

</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>prob_cut</th>
      <th>treatment</th>
      <th>coverage_confidence_interval</th>
      <th>coverage_cf_confidence_interval</th>
      <th>coverage_mean</th>
      <th>coverage_cf_mean</th>
      <th>ic_size_t_0_mean</th>
      <th>ic_size_t_1_mean</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>(0.0, 0.2]</td>
      <td>0.0</td>
      <td>[0.93669, 0.95635]</td>
      <td>[0.93046, 0.95108]</td>
      <td>0.947242</td>
      <td>0.941487</td>
      <td>1.980407</td>
      <td>6.920811</td>
    </tr>
    <tr>
      <th>2</th>
      <td>(0.2, 0.4]</td>
      <td>0.0</td>
      <td>[0.94852, 0.96972]</td>
      <td>[0.93565, 0.95988]</td>
      <td>0.959879</td>
      <td>0.948524</td>
      <td>2.032159</td>
      <td>4.205569</td>
    </tr>
    <tr>
      <th>4</th>
      <td>(0.4, 0.6]</td>
      <td>0.0</td>
      <td>[0.93891, 0.96946]</td>
      <td>[0.94024, 0.96946]</td>
      <td>0.956175</td>
      <td>0.956175</td>
      <td>2.054778</td>
      <td>2.642784</td>
    </tr>
    <tr>
      <th>6</th>
      <td>(0.6, 0.8]</td>
      <td>0.0</td>
      <td>[0.92321, 0.96071]</td>
      <td>[0.94464, 0.97679]</td>
      <td>0.944643</td>
      <td>0.962500</td>
      <td>1.949311</td>
      <td>4.066495</td>
    </tr>
    <tr>
      <th>8</th>
      <td>(0.8, 1.0]</td>
      <td>0.0</td>
      <td>[0.91579, 0.96842]</td>
      <td>[0.92982, 0.97895]</td>
      <td>0.947368</td>
      <td>0.957895</td>
      <td>2.180449</td>
      <td>6.671260</td>
    </tr>
    <tr>
      <th>1</th>
      <td>(0.0, 0.2]</td>
      <td>1.0</td>
      <td>[0.90459, 0.96466]</td>
      <td>[0.92226, 0.97527]</td>
      <td>0.939929</td>
      <td>0.954064</td>
      <td>1.985375</td>
      <td>6.573730</td>
    </tr>
    <tr>
      <th>3</th>
      <td>(0.2, 0.4]</td>
      <td>1.0</td>
      <td>[0.8998, 0.94499]</td>
      <td>[0.91749, 0.95874]</td>
      <td>0.925344</td>
      <td>0.941061</td>
      <td>2.043951</td>
      <td>4.025180</td>
    </tr>
    <tr>
      <th>5</th>
      <td>(0.4, 0.6]</td>
      <td>1.0</td>
      <td>[0.93103, 0.96296]</td>
      <td>[0.93103, 0.96296]</td>
      <td>0.948914</td>
      <td>0.948914</td>
      <td>2.054981</td>
      <td>2.636064</td>
    </tr>
    <tr>
      <th>7</th>
      <td>(0.6, 0.8]</td>
      <td>1.0</td>
      <td>[0.94165, 0.96353]</td>
      <td>[0.9329, 0.95697]</td>
      <td>0.953319</td>
      <td>0.945295</td>
      <td>1.973618</td>
      <td>4.291990</td>
    </tr>
    <tr>
      <th>9</th>
      <td>(0.8, 1.0]</td>
      <td>1.0</td>
      <td>[0.94, 0.95902]</td>
      <td>[0.94049, 0.95951]</td>
      <td>0.950244</td>
      <td>0.950732</td>
      <td>2.214706</td>
      <td>6.862401</td>
    </tr>
  </tbody>
</table>
</div>

<p><div align="justify">Indeed, we also seem to be doing a reasonable job in terms of conditional coverage: in most buckets it is very close to the 95% requested from <code>ConformalizedQuantileRegression</code>, the lowest value being about 92.5% (<code>prob_cut=(0.2, 0.4]</code> with $T=1$). This suggests that even in regions with few examples with treatment $T=0$ (for instance, where <code>prob_cut=(0.8, 1.0]</code>), coverage remains close to the nominal level.</div></p>

<p><div align="justify">$\oint$ <em>Since estimating $P((Y|Z=z)\in \tau(Z=z) \,|\, T=t, Z=z)$ is essentially a binary classification problem (the label is whether the interval covers the true value), another viable strategy is to train a classifier to predict coverage and inspect its probabilistic output.</em></div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">probs_coverage</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">LogisticRegression</span><span class="p">()</span>
    <span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">df_val</span><span class="p">.</span><span class="nb">filter</span><span class="p">([</span><span class="s">"treatment"</span><span class="p">,</span> <span class="s">"confounder"</span><span class="p">]),</span> <span class="n">df_val</span><span class="p">.</span><span class="n">coverage</span><span class="p">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">))</span>
    <span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">df_val</span><span class="p">.</span><span class="nb">filter</span><span class="p">([</span><span class="s">"treatment"</span><span class="p">,</span> <span class="s">"confounder"</span><span class="p">]))[:,</span> <span class="mi">1</span><span class="p">]</span>
<span class="p">)</span>

<span class="n">roc_auc_score</span><span class="p">(</span><span class="n">df_val</span><span class="p">.</span><span class="n">coverage</span><span class="p">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">),</span> <span class="n">probs_coverage</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.5152119817684396
</code></pre></div></div>

<p><div align="justify"><em>The AUC close to 0.5 indicates that the classifier cannot identify regions with poor coverage (at least a logistic regression on these two features cannot). Moreover, the minimum of the estimated conditional probabilities (whose calibration we have not carefully verified) remains reasonably high.</em></div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nb">min</span><span class="p">(</span><span class="n">probs_coverage</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">probs_coverage</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(0.9405329900858612, 0.9577096822516356)
</code></pre></div></div>

<p><div align="justify">$\oint$ <em>It's also common to evaluate conditional coverage as a function of the predicted interval size (partitioning the intervals into &quot;small&quot;, &quot;medium&quot;, and &quot;large&quot;) <a href="#bibliography">[1]</a>. In a real application I would do this as well, but to avoid overloading this text with code, the analysis above is enough to illustrate the exercise.</em></div></p>
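<p><div align="justify">A minimal sketch of that size-stratified check, assuming we have a boolean coverage indicator per sample and the width of the interval it refers to. The helper <code>coverage_by_interval_size</code> and the equal-frequency split into three buckets are illustrative choices, and the toy data below only demonstrates the output format.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
import pandas as pd


def coverage_by_interval_size(covered, interval_size, labels=("small", "medium", "large")):
    """Hypothetical helper: empirical coverage within equal-frequency interval-size buckets."""
    df = pd.DataFrame(
        {
            "covered": np.asarray(covered, dtype=float),
            "interval_size": np.asarray(interval_size, dtype=float),
        }
    )
    # Split the interval widths into len(labels) quantile buckets (terciles by default).
    df["size_bucket"] = pd.qcut(df["interval_size"], q=len(labels), labels=list(labels))
    # Mean coverage and number of samples per bucket.
    return df.groupby("size_bucket", observed=True)["covered"].agg(["mean", "size"])


# Toy usage with synthetic indicators (about 95% coverage), only to show the output format.
rng = np.random.default_rng(0)
sizes = rng.uniform(1, 7, 3_000)
covered = rng.uniform(size=3_000) > 0.05
print(coverage_by_interval_size(covered, sizes))
</code></pre></div></div>

<p><div align="justify">With the objects from this post, one could pass <code>df_val.coverage</code> together with <code>np.where(df_val.treatment == 0, df_val.ic_size_t_0, df_val.ic_size_t_1)</code>, i.e., the width of the interval used for the observed outcome.</div></p>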

<hr />

<h2 id="joining-confidence-intervals">Joining confidence intervals</h2>

<p><div align="justify">While our estimates appear coherent, what we ultimately want is an interval for the treatment effect, which requires subtracting one predicted interval from the other. Combining intervals while preserving coverage isn't straightforward, so let's examine this scenario in more detail.</div></p>

<p><div align="justify">Let's assume we have two random variables with given probabilities of being within certain intervals: </div></p>

\[\mathbb{P}(A \in (m_a, M_a)) \geq 1 - \alpha, \mathbb{P}(B \in (m_b, M_b)) \geq 1 - \beta.\]

<p><div align="justify">Observe that if both events occur, the sum of the random variables lies in the interval obtained by adding the corresponding endpoints of the two intervals. In other words,</div></p>

\[\{A \in (m_a, M_a)\} \cap \{ B \in (m_b, M_b)\} \subset  \{A + B \in (m_a + m_b, M_a + M_b)\}.\]

<p><div align="justify">In probability theory, if an event is contained in another, its probability is at most the probability of the larger event (monotonicity), so</div></p>

\[\mathbb{P}(\{A \in (m_a, M_a)\} \cap \{ B \in (m_b, M_b)\}) \leq \mathbb{P}(\{A + B \in (m_a + m_b, M_a + M_b)\}).\]

<p><div align="justify">From here, let's derive a lower bound for the left-hand side. The probability of its complement can be bounded as</div></p>

\[\begin{align*}
    \mathbb{P}(\left(\{A \in (m_a, M_a)\} \cap \{ B \in (m_b, M_b)\}\right)^C) &amp;= \mathbb{P}(\{A \in (m_a, M_a)\}^C \cup \{ B \in (m_b, M_b)\}^C)\\
    &amp;\leq \mathbb{P}(\{A \in (m_a, M_a)\}^C) + \mathbb{P}(\{ B \in (m_b, M_b)\}^C),
\end{align*}\]

<p><div align="justify">using De Morgan's laws and then bounding the probability of the union by the sum of the probabilities (the union bound, also known as Boole's inequality).</div></p>

<p><div align="justify">Following this, we can conclude that</div></p>

\[\begin{align*}
    \mathbb{P}(\left(\{A \in (m_a, M_a)\} \cap \{ B \in (m_b, M_b)\}\right)^C) &amp;\leq 1 - \mathbb{P}(\{A \in (m_a, M_a)\}) + 1 - \mathbb{P}(\{ B \in (m_b, M_b)\})\\
    &amp;\leq 1 - (1 - \alpha) + 1 - (1 - \beta) = \alpha + \beta.
\end{align*}\]

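<p><div align="justify">$\oint$ <em>This inequality can be loose when $\{A \in (m_a, M_a)\}^C$ and $\{ B \in (m_b, M_b)\}^C$ overlap substantially: bounding the probability of the union by the sum of the probabilities implicitly treats their intersection as having zero probability, so the bound is tight only when these events are disjoint. The inclusion used at the start is another source of slack, since $A + B$ can fall in the combined interval even when $A$ or $B$ falls outside its own interval.</em></div></p>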
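<p><div align="justify">Inclusion-exclusion makes the slack of the union bound explicit:</div></p>

\[\mathbb{P}(\{A \in (m_a, M_a)\}^C \cup \{ B \in (m_b, M_b)\}^C) = \mathbb{P}(\{A \in (m_a, M_a)\}^C) + \mathbb{P}(\{ B \in (m_b, M_b)\}^C) - \mathbb{P}(\{A \in (m_a, M_a)\}^C \cap \{ B \in (m_b, M_b)\}^C).\]

<p><div align="justify">For illustration, with miscoverage probabilities of exactly $\alpha = \beta = 0.05$, the union bound charges $0.10$; if the two miscoverage events were independent, the union would have probability $0.05 + 0.05 - 0.05 \cdot 0.05 = 0.0975$, and if they always happened together it would be only $0.05$. In general, the probability of the union lies between $\max(\alpha, \beta)$ and $\alpha + \beta$.</div></p>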

<p><div align="justify">Since</div></p>

\[\mathbb{P}(\left(\{A \in (m_a, M_a)\} \cap \{ B \in (m_b, M_b)\}\right)^C) \leq \alpha + \beta,\]

<p><div align="justify">we find</div></p>

\[\mathbb{P}(\left(\{A \in (m_a, M_a)\} \cap \{ B \in (m_b, M_b)\}\right)) = 1 - \mathbb{P}(\left(\{A \in (m_a, M_a)\} \cap \{ B \in (m_b, M_b)\}\right)^C) \geq 1 - (\alpha + \beta).\]

<p><div align="justify">From this, we can deduce that since</div></p>

\[\mathbb{P}(\{A \in (m_a, M_a)\} \cap \{ B \in (m_b, M_b)\}) \leq  \mathbb{P}(\{A + B \in (m_a + m_b, M_a + M_b)\}),\]

<p><div align="justify">we obtain an inequality for the interval resulting from the sum of the ends of the initial intervals:</div></p>

\[\mathbb{P}(\{A + B \in (m_a + m_b, M_a + M_b)\}) \geq 1 - (\alpha + \beta).\]

<p><div align="justify">$\oint$ <em>This is the same argument behind the Bonferroni correction for multiple hypothesis testing, which follows from Boole's inequality <a href="#bibliography">[11]</a>.</em></div></p>
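<p><div align="justify">A quick Monte Carlo sanity check of this result, under assumptions chosen only for illustration: $A$ and $B$ are independent standard normal variables and each interval is the exact central interval with $\alpha = \beta = 0.1$, so the bound promises at least 80% coverage. The function <code>check_sum_interval_coverage</code> is hypothetical and not part of the original analysis.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from scipy.stats import norm


def check_sum_interval_coverage(alpha=0.1, beta=0.1, n=200_000, seed=0):
    """Hypothetical Monte Carlo check of P(A + B in (m_a + m_b, M_a + M_b)) >= 1 - (alpha + beta)."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal(n)
    b = rng.standard_normal(n)  # independent of a in this illustration

    # Exact central intervals for standard normals: P(A in (m_a, M_a)) = 1 - alpha, etc.
    m_a, M_a = norm.ppf(alpha / 2), norm.ppf(1 - alpha / 2)
    m_b, M_b = norm.ppf(beta / 2), norm.ppf(1 - beta / 2)

    in_a = (a > m_a) & (M_a > a)
    in_b = (b > m_b) & (M_b > b)
    s = a + b
    in_sum = (s > m_a + m_b) & (M_a + M_b > s)

    return {
        "bound": 1 - (alpha + beta),
        "both_intervals": float(np.mean(in_a & in_b)),
        "sum_interval": float(np.mean(in_sum)),
    }


print(check_sum_interval_coverage())
</code></pre></div></div>

<p><div align="justify">In this independent case, both intervals hold simultaneously about 81% of the time ($0.9 \times 0.9$), while the combined interval covers $A + B$ roughly 98% of the time, since $A + B \sim \mathcal{N}(0, 2)$ and $2 \cdot 1.645 / \sqrt{2} \approx 2.33$ standard deviations. The guarantee holds, but it is conservative.</div></p>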

<hr />

<h2 id="prediction-interval-of-cate">Prediction interval of CATE</h2>

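<p><div align="justify">In our particular scenario, we are working with $A = (Y|do(T=1), Z=z)$ and $B = - (Y|do(T=0), Z=z)$ (the conformal intervals cover the outcomes themselves, not their expectations). Because of the minus sign, the limits of the interval for $B$ are the negated and swapped limits of <code>y_test_interval_pred_cqr_t0</code>: if that interval is $(m, M)$, the interval for $B$ is $(-M, -m)$. Since both conformal regressions were asked for 95% coverage ($\alpha = \beta = 0.05$), the result above guarantees at least $1 - (0.05 + 0.05) = 90\%$ coverage for the combined interval, assuming each interval attains its nominal coverage.</div></p>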

<p><div align="justify">Once again, it would be valuable to assess the coverage and size of the intervals that we have now created.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">df_val_cate</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">df_val</span><span class="p">.</span><span class="n">assign</span><span class="p">(</span>
        <span class="c1"># Individual effect Y(1) - Y(0): the counterfactual column holds the unobserved outcome.
</span>        <span class="n">cate_actual</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">where</span><span class="p">(</span>
            <span class="n">df_</span><span class="p">.</span><span class="n">treatment</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span>
            <span class="n">df_</span><span class="p">.</span><span class="n">target_cf</span> <span class="o">-</span> <span class="n">df_</span><span class="p">.</span><span class="n">target</span><span class="p">,</span>
            <span class="n">df_</span><span class="p">.</span><span class="n">target</span> <span class="o">-</span> <span class="n">df_</span><span class="p">.</span><span class="n">target_cf</span><span class="p">,</span>
        <span class="p">)</span>
    <span class="p">)</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">cate_ci_lower</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_lower_t_1</span> <span class="o">-</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_upper_t_0</span><span class="p">)</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">cate_ci_upper</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_upper_t_1</span> <span class="o">-</span> <span class="n">df_</span><span class="p">.</span><span class="n">pred_lower_t_0</span><span class="p">)</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">cate_ci_size</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="n">df_</span><span class="p">.</span><span class="n">cate_ci_upper</span> <span class="o">-</span> <span class="n">df_</span><span class="p">.</span><span class="n">cate_ci_lower</span><span class="p">)</span>
    <span class="p">.</span><span class="n">assign</span><span class="p">(</span>
        <span class="n">coverage_cate</span><span class="o">=</span><span class="k">lambda</span> <span class="n">df_</span><span class="p">:</span> <span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">cate_actual</span> <span class="o">&gt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">cate_ci_lower</span><span class="p">)</span>
        <span class="o">&amp;</span> <span class="p">(</span><span class="n">df_</span><span class="p">.</span><span class="n">cate_actual</span> <span class="o">&lt;</span> <span class="n">df_</span><span class="p">.</span><span class="n">cate_ci_upper</span><span class="p">)</span>
    <span class="p">)</span>
<span class="p">)</span>
</code></pre></div></div>

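<p><div align="justify">As expected, the CATE prediction intervals are wider than the ones found earlier. In fact, <code>cate_ci_size</code> is exactly <code>ic_size_t_0 + ic_size_t_1</code>, because $(u_1 - l_0) - (l_1 - u_0) = (u_1 - l_1) + (u_0 - l_0)$, where $l_t$ and $u_t$ are the lower and upper limits predicted for $T=t$.</div></p>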

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">ncols</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
<span class="n">aux_hist</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">hstack</span><span class="p">([</span><span class="n">df_val</span><span class="p">.</span><span class="n">ic_size_t_0</span><span class="p">,</span> <span class="n">df_val</span><span class="p">.</span><span class="n">ic_size_t_1</span><span class="p">])</span>
<span class="n">min_hist</span><span class="p">,</span> <span class="n">max_hist</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">aux_hist</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">aux_hist</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">hist</span><span class="p">(</span>
    <span class="n">df_val</span><span class="p">.</span><span class="n">ic_size_t_0</span><span class="p">,</span>
    <span class="n">bins</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">min_hist</span><span class="p">,</span> <span class="n">max_hist</span><span class="p">,</span> <span class="mi">16</span><span class="p">),</span>
    <span class="n">weights</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">df_val</span><span class="p">.</span><span class="n">ic_size_t_0</span><span class="p">)</span> <span class="o">/</span> <span class="n">df_val</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">hist</span><span class="p">(</span>
    <span class="n">df_val</span><span class="p">.</span><span class="n">ic_size_t_1</span><span class="p">,</span>
    <span class="n">bins</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">min_hist</span><span class="p">,</span> <span class="n">max_hist</span><span class="p">,</span> <span class="mi">16</span><span class="p">),</span>
    <span class="n">weights</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">df_val</span><span class="p">.</span><span class="n">ic_size_t_1</span><span class="p">)</span> <span class="o">/</span> <span class="n">df_val</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">].</span><span class="n">hist</span><span class="p">(</span>
    <span class="n">df_val_cate</span><span class="p">.</span><span class="n">cate_ci_size</span><span class="p">,</span>
    <span class="n">bins</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
    <span class="n">weights</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">df_val_cate</span><span class="p">.</span><span class="n">cate_ci_size</span><span class="p">)</span> <span class="o">/</span> <span class="n">df_val_cate</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span>
    <span class="sa">r</span><span class="s">"Histogram of interval size for $\mathbb{E}(Y | do(T=0), Z=z)$"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="s">"medium"</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span>
    <span class="sa">r</span><span class="s">"Histogram of interval size for $\mathbb{E}(Y | do(T=1), Z=z)$"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="s">"medium"</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Histogram of interval size for CATE(Z=z)"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="s">"medium"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/cqr_cate/output_36_0.png" /></center></div></p>

<p><div align="justify">Even though our individual prediction intervals were constructed for a coverage of $1 - \alpha = 0.95$, the union bound (as in the Bonferroni correction [<a href="#bibliography">11</a>]) only guarantees a coverage of at least $1 - (0.05 + 0.05) = 0.9$ for our CATE prediction intervals. However, as we discussed before, this is a loose bound, and the actual coverage is substantially higher.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">df_val_cate</span><span class="p">.</span><span class="n">coverage_cate</span><span class="p">.</span><span class="n">mean</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.9997
</code></pre></div></div>
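<p><div align="justify">To see why the union bound is so loose, the sketch below uses a hypothetical setup that is not the post's data: two independent standard normal potential outcomes, each covered by its exact 95% interval. The interval for the difference subtracts opposite endpoints (lower of one minus upper of the other, and vice versa), so it contains $Y(1) - Y(0)$ whenever both individual intervals contain their targets, and often even when one of them does not.</div></p>

```python
import numpy as np

# Hypothetical setup (not the post's data): independent standard normal
# potential outcomes, each with its exact 95% interval [-z, z].
rng = np.random.default_rng(0)
n_samples = 200_000
alpha = 0.05
z = 1.959964  # 97.5% quantile of the standard normal

y_t1 = rng.normal(size=n_samples)
y_t0 = rng.normal(size=n_samples)
lower_t1, upper_t1 = -z, z
lower_t0, upper_t0 = -z, z

# Interval for the difference: pair each endpoint with the opposite one.
lower_diff = lower_t1 - upper_t0
upper_diff = upper_t1 - lower_t0

covered_t1 = (y_t1 > lower_t1) & (y_t1 < upper_t1)
covered_t0 = (y_t0 > lower_t0) & (y_t0 < upper_t0)
diff = y_t1 - y_t0
covered_diff = (diff > lower_diff) & (diff < upper_diff)

bonferroni_bound = 1 - 2 * alpha  # what the union bound guarantees
both_covered = np.mean(covered_t1 & covered_t0)
coverage_diff = np.mean(covered_diff)

print(f"Union-bound guarantee:           {bonferroni_bound:.3f}")
print(f"Both individual intervals hit:   {both_covered:.3f}")
print(f"Interval for the difference hit: {coverage_diff:.3f}")
```

<p><div align="justify">Both individual intervals cover their targets jointly about $0.95^2 \approx 0.90$ of the time, yet the interval for the difference covers far more often (close to 0.99 in this setup), which mirrors the conservative CATE coverage we observed.</div></p>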

<p><div align="justify">Since $Z\in\mathbb{R}$, we can visually evaluate our conformal estimator by plotting the prediction intervals for each meta-estimator and for our CATE estimate. Moreover, since we control the noise variance, we can also plot the true 95% confidence interval.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">confounder_plot</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">XZ_test</span><span class="p">.</span><span class="n">confounder</span><span class="p">.</span><span class="nb">min</span><span class="p">(),</span> <span class="n">XZ_test</span><span class="p">.</span><span class="n">confounder</span><span class="p">.</span><span class="nb">max</span><span class="p">(),</span> <span class="mi">10_001</span><span class="p">)</span>
<span class="n">ci_t1_plot</span> <span class="o">=</span> <span class="n">model_t1</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)).</span><span class="n">T</span>
<span class="n">ci_t0_plot</span> <span class="o">=</span> <span class="n">model_t0</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)).</span><span class="n">T</span>
<span class="n">ci_cate_plot</span> <span class="o">=</span> <span class="n">ci_t1_plot</span> <span class="o">-</span> <span class="n">ci_t0_plot</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">,]</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>

<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">confounder_plot</span><span class="p">,</span>
    <span class="n">func_0</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="mf">0.5</span><span class="p">,</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"C0"</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="sa">r</span><span class="s">"Real confidence interval for $\mathbb{E}(Y | do(T=0), Z=z)$"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">,</span> <span class="n">func_0</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"C0"</span><span class="p">)</span>

<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">confounder_plot</span><span class="p">,</span>
    <span class="n">func_1</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">))),</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"C1"</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="sa">r</span><span class="s">"Real confidence interval for $\mathbb{E}(Y | do(T=1), Z=z)$"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">confounder_plot</span><span class="p">,</span>
    <span class="n">func_1</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">))),</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"C1"</span><span class="p">,</span>
<span class="p">)</span>

<span class="c1"># The standard deviation of the effect at Z=z is 0.5 * |z|: u_Y is shared
# by both potential outcomes, so the 0.5 * u_Y part of the noise cancels in
# g_Y(u_Y, z, 1) - g_Y(u_Y, z, 0), leaving only 0.5 * |z| * u_Y.
</span><span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">confounder_plot</span><span class="p">,</span>
    <span class="n">func_1</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">)</span>
    <span class="o">-</span> <span class="n">func_0</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">)</span>
    <span class="o">+</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">))),</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"C2"</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"Real confidence interval for CATE(Z=z)"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">confounder_plot</span><span class="p">,</span>
    <span class="n">func_1</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">)</span>
    <span class="o">-</span> <span class="n">func_0</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">)</span>
    <span class="o">-</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">confounder_plot</span><span class="p">))),</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"C2"</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">ax</span><span class="p">.</span><span class="n">fill_between</span><span class="p">(</span>
    <span class="n">confounder_plot</span><span class="p">,</span>
    <span class="o">*</span><span class="n">ci_t0_plot</span><span class="p">,</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="sa">r</span><span class="s">"Prediction interval for $\mathbb{E}(Y | do(T=0), Z=z)$"</span><span class="p">,</span>
    <span class="n">color</span><span class="o">=</span><span class="s">"C0"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">fill_between</span><span class="p">(</span>
    <span class="n">confounder_plot</span><span class="p">,</span>
    <span class="o">*</span><span class="n">ci_t1_plot</span><span class="p">,</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="sa">r</span><span class="s">"Prediction interval for $\mathbb{E}(Y | do(T=1), Z=z)$"</span><span class="p">,</span>
    <span class="n">color</span><span class="o">=</span><span class="s">"C1"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">fill_between</span><span class="p">(</span>
    <span class="n">confounder_plot</span><span class="p">,</span>
    <span class="o">*</span><span class="n">ci_cate_plot</span><span class="p">,</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"Prediction interval for CATE(Z=z)"</span><span class="p">,</span>
    <span class="n">color</span><span class="o">=</span><span class="s">"C2"</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"z"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/cqr_cate/output_40_0.png" /></center></div></p>

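<p><div align="justify">Indeed, the prediction intervals for $\mathbb{E}(Y | do(T=0), Z=z)$ and $\mathbb{E}(Y | do(T=1), Z=z)$ seem to align closely with the theoretical confidence intervals. The exception is the CATE, whose prediction interval is wider than the theoretical one, in line with the 0.9997 coverage observed above.</div></p>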

<hr />

<h2 id="final-considerations">Final considerations</h2>

<p><div align="justify">The CATE is a valuable quantity in many applied causal inference scenarios. Bringing conformal prediction into CATE estimation lets us attach uncertainty quantification to these estimates and use it in our analyses and decisions. In this exploration, Conformalized Quantile Regression proved to be a robust method for estimating the CATE while also offering reliable, if somewhat conservative, uncertainty quantification.</div></p>

<p><div align="justify">$\oint$ <em>After writing this post, I took a closer look at the discussions connecting causal inference with conformal prediction and found the article <a href="https://arxiv.org/abs/2006.06138">Conformal Inference of Counterfactuals and Individual Treatment Effects</a> very interesting. There, the authors also experiment with variations of CQR, but with a doubly robust estimator. They seem to pay special attention to conformal prediction under covariate shift, which is exactly the scenario we are addressing here, and are noticeably more cautious when deploying CQR in this context. In this post, I only used a <code>sample_weight</code> that is also applied when computing the quantiles of the conformal calibration set.</em></div></p>
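<p><div align="justify">A minimal sketch of how a <code>sample_weight</code> can enter the calibration quantile is shown below. It is not the post's implementation: the helpers <code>cqr_scores</code> and <code>weighted_quantile</code> are hypothetical names, the conformity score is the usual CQR one, $\max\left(\hat{q}_{\textrm{low}}(x) - y,\, y - \hat{q}_{\textrm{high}}(x)\right)$, and the threshold is the smallest score whose normalized cumulative weight reaches $1 - \alpha$. Finite-sample corrections, such as the $(n+1)$ adjustment of split conformal prediction, are deliberately left out.</div></p>

```python
import numpy as np


def cqr_scores(y, q_lower, q_upper):
    """CQR conformity score: how far y falls outside [q_lower, q_upper]."""
    return np.maximum(q_lower - y, y - q_upper)


def weighted_quantile(scores, sample_weight, level):
    """Smallest score whose normalized cumulative weight reaches `level`.

    Hypothetical helper: with equal weights it behaves like an
    inverted-CDF empirical quantile.
    """
    scores = np.asarray(scores, dtype=float)
    sample_weight = np.asarray(sample_weight, dtype=float)
    order = np.argsort(scores)
    cum_weight = np.cumsum(sample_weight[order]) / sample_weight.sum()
    idx = np.searchsorted(cum_weight, level, side="left")
    # Guard against round-off leaving idx == len(scores).
    return scores[order][min(idx, len(scores) - 1)]


# Toy calibration set (hypothetical numbers).
y_cal = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
q_lower_cal = np.array([0.5, 2.5, 2.0, 4.5, 3.0])
q_upper_cal = np.array([1.5, 3.5, 3.5, 5.5, 4.0])
weights_cal = np.ones_like(y_cal)
alpha = 0.3

scores = cqr_scores(y_cal, q_lower_cal, q_upper_cal)
q_hat = weighted_quantile(scores, weights_cal, 1 - alpha)

# Conformalized interval for a new point whose quantile predictions are [2, 3].
interval = (2.0 - q_hat, 3.0 + q_hat)
print(scores, q_hat, interval)
```

<p><div align="justify">Upweighting calibration points with large scores (for instance, points that resemble the target population under covariate shift) pushes the threshold up and widens every interval; upweighting points with small scores does the opposite.</div></p>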

<h2 id="bibliography"><a name="bibliography">Bibliography</a></h2>

<p><div align="justify">[1] <a href="http://people.eecs.berkeley.edu/~angelopoulos/blog/posts/gentle-intro/">A Gentle Introduction to Conformal Prediction and Distribution-Free Uncertainty Quantification. Anastasios N. Angelopoulos, Stephen Bates.</a></div></p>

<p><div align="justify">[2] <a href="https://github.com/rbstern/causality_book/blob/435e920d7d68872fea1be187b0dcf6c5e8b3a55e/book.pdf">Class notes on Causal Inference (PTBR). Rafael Bassi Stern.</a></div></p>

<p><div align="justify">[3] <a href="https://matheusfacure.github.io/python-causality-handbook/landing-page.html">Causal Inference for The Brave and True. Matheus Facure.</a></div></p>

<p><div align="justify">[4] <a href="https://youtube.com/playlist?list=PLoazKTcS0RzZ1SUgeOgc6SWt51gfT80N0">Causal Inference Course. Brady Neal.</a></div></p>

<p><div align="justify">[5] <a href="https://matheusfacure.github.io/python-causality-handbook/landing-page.html">Causal Inference on Observational Data: It&#39;s All About the Assumptions. Jean-Yves Gérardy.</a></div></p>

<p><div align="justify">[6] <a href="https://scikit-learn.org/stable/modules/calibration.html">Probability calibration. Scikit-Learn User Guide.</a></div></p>

<p><div align="justify">[7] <a href="https://arxiv.org/abs/1102.2101">Estimating conditional quantiles with the help of the pinball loss. Ingo Steinwart, Andreas Christmann.</a></div></p>

<p><div align="justify">[8] <a href="https://towardsdatascience.com/how-to-predict-risk-proportional-intervals-with-conformal-quantile-regression-175775840dc4">How to Predict Risk-Proportional Intervals with Conformal Quantile Regression. Samuele Mazzanti.</a></div></p>

<p><div align="justify">[9] <a href="https://statisticaloddsandends.wordpress.com/2022/05/20/t-learners-s-learners-and-x-learners/">T-learners, S-learners and X-learners. Statistical Odds &amp; Ends.</a></div></p>

<p><div align="justify">[10] <a href="https://matheusfacure.github.io/python-causality-handbook/landing-page.html">Analysis of Kernel Mean Matching under Covariate Shift. Yaoliang Yu, Csaba Szepesvari.</a></div></p>

<p><div align="justify">[11] <a href="https://en.wikipedia.org/wiki/Bonferroni_correction">Bonferroni correction. Wikipedia.</a></div></p>

<hr />

<p><div align="justify">You can find all files and environments for reproducing the experiments in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/cqr_cate">repository of this post</a>.</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇺🇸&quot;, &quot;uncertainty quantification&quot;, &quot;causal inference&quot;]" /><summary type="html"><![CDATA[Applying conformalized quantile regression in an important causal inference task.]]></summary></entry><entry><title type="html">Conditional Density Estimation</title><link href="https://vitaliset.github.io/conditional-density-estimation/" rel="alternate" type="text/html" title="Conditional Density Estimation" /><published>2023-06-16T00:00:00+00:00</published><updated>2023-06-16T00:00:00+00:00</updated><id>https://vitaliset.github.io/conditional-density-estimation</id><content type="html" xml:base="https://vitaliset.github.io/conditional-density-estimation/"><![CDATA[<p><div align="justify">Typically, when we model the relationship between a target variable $Y\in\mathbb{R}$ and one or more covariates $X$, our goal is to estimate a conditional expectation. Mathematically, if our loss is the mean squared error, our explicit aim is to identify the function $\mathbb{E} \left( Y \,|\, X=x\right)$, which intuitively predicts the average value of $Y$ given that the covariates are $X=x$. Point estimates like this are simple and easy to summarize, but they often fail to capture the intricacies and uncertainties present in most real-world predictive scenarios. This prompts us to ask: is the variance around this average value large, or can we confidently expect the value to be close to the predicted one?</div></p>
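<p><div align="justify">The link between the squared error and the conditional expectation follows from adding and subtracting $m(x) = \mathbb{E}\left(Y \,|\, X=x\right)$ inside the loss; the cross term vanishes because $\mathbb{E}\left(Y - m(x) \,|\, X=x\right) = 0$:</div></p>

\[\mathbb{E}\left(\left(Y - c\right)^2 \,|\, X=x\right) = \mathbb{E}\left(\left(Y - m(x)\right)^2 \,|\, X=x\right) + \left(m(x) - c\right)^2.\]

<p><div align="justify">Hence, for each $x$, the constant $c$ with the smallest expected squared error is $c = m(x)$. The first term, the conditional variance, is precisely the spread that a point estimate leaves unreported.</div></p>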

<p><div align="justify">Instead of a single point estimate, Conditional Density Estimation (CDE) aims to describe how plausible an entire range of potential outcomes is, given specific input data. In mathematical terms, we estimate the conditional probability density function $f \left( y \,|\, X=x \right)$.</div></p>

<p><div align="justify">Because it describes the whole distribution, CDE offers a deeper understanding of the data and is particularly helpful for two fundamental tasks: evaluating model trustworthiness and accommodating multi-modal outcomes.</div></p>

<ol>
  <li>
    <p><div align="justify"><strong>Model trustworthiness:</strong> Unlike point predictions, which offer no insight into their own reliability, CDE provides a full distribution of potential outcomes and therefore inherently conveys how confident a prediction is. The variance of the predicted distribution can act as a measure of uncertainty, giving users a more complete picture of each prediction. This is critical when decisions are made based on these predictions. For instance, in healthcare, a prediction about a patient's outcome accompanied by its uncertainty could lead to more informed and suitable medical decisions.</div></p>
  </li>
  <li>
    <p><div align="justify"><strong>Multi-modal outcomes:</strong> Traditional regression or classification models, generally focused on single point predictions, often fall short of capturing the full complexity of real-world phenomena. This shortfall is especially apparent when a single input could plausibly yield multiple valid outputs, a situation called multi-modality. Consider predicting someone's salary from a set of features that does not tell us whether the individual lives in a state with a high or a low average salary. A sensible salary estimate should not simply average the two regions. Instead, it is more fitting to present a bimodal distribution with two distinct peaks, each denoting a plausible salary range depending on which state the individual lives in.</div></p>
  </li>
</ol>
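<p><div align="justify">A quick numerical illustration of the salary example, using made-up numbers that do not come from real data: a 50/50 mixture of two Gaussian salary distributions centered at 3,000 and 9,000, each with a standard deviation of 800. The helper names are hypothetical. The conditional mean lands between the two peaks, in a region where the density is almost zero.</div></p>

```python
import numpy as np


def normal_pdf(y, loc, scale):
    """Density of a Gaussian with mean `loc` and standard deviation `scale`."""
    return np.exp(-0.5 * ((y - loc) / scale) ** 2) / (scale * np.sqrt(2 * np.pi))


# Hypothetical bimodal conditional density: two states, equally likely.
state_probs = np.array([0.5, 0.5])
state_means = np.array([3_000.0, 9_000.0])
state_stds = np.array([800.0, 800.0])


def salary_density(y):
    """Mixture density f(y | x) evaluated at each entry of `y`."""
    y = np.atleast_1d(np.asarray(y, dtype=float))[:, None]
    return (state_probs * normal_pdf(y, state_means, state_stds)).sum(axis=1)


conditional_mean = float(state_probs @ state_means)
density_at_mean = salary_density(conditional_mean)[0]
density_at_peaks = salary_density(state_means)

print(f"Conditional mean: {conditional_mean:.0f}")
print(f"Density at the mean / density at a peak: "
      f"{density_at_mean / density_at_peaks[0]:.4f}")
```

<p><div align="justify">A point prediction at the conditional mean would be a poor summary for individuals in either state, while the density makes both plausible ranges visible.</div></p>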

<p><div align="justify">$\oint$ <em>The field of conformal prediction aims to address this uncertainty by estimating prediction sets $\tau(X=x)$ such that $\mathbb{P}\left(\left(Y\,|\,X=x\right) \in \tau(X=x)\right) \geq 1 - \alpha$ for a desired coverage level $1 - \alpha$ [<a href="#bibliography">1</a>]. Interpreting these prediction sets, for instance by inspecting their size, begins to answer some of the questions we raised earlier. However, in regression tasks, the prediction set is usually an interval, and its endpoints, which naturally estimate conditional quantiles, do not fully portray the uncertainty associated with the prediction. This limitation is particularly evident with multi-modal densities. Likewise, if you have a utility function associated with your predictions and want the expected utility for an individual, the natural approach is to integrate the utility against that individual's conditional density.</em></div></p>
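<p><div align="justify">The expected utility for an individual is $\int u(y)\, f\left(y \,|\, X=x\right) dy$. The sketch below approximates this integral with a Riemann sum on a fine grid, assuming a hypothetical Gaussian conditional density with mean 1 and standard deviation 0.5, and a hypothetical utility $u(y) = \min(y, 1)$ that caps the reward at 1. Neither choice comes from the post; they only illustrate the computation.</div></p>

```python
import numpy as np


def normal_pdf(y, loc, scale):
    """Density of a Gaussian with mean `loc` and standard deviation `scale`."""
    return np.exp(-0.5 * ((y - loc) / scale) ** 2) / (scale * np.sqrt(2 * np.pi))


def expected_utility(utility, density, y_grid):
    """Riemann-sum approximation of the integral of utility(y) * density(y)."""
    dy = y_grid[1] - y_grid[0]
    return float(np.sum(utility(y_grid) * density(y_grid)) * dy)


def density(y):
    """Hypothetical conditional density f(y | X=x): Gaussian, mean 1, std 0.5."""
    return normal_pdf(y, loc=1.0, scale=0.5)


def capped_utility(y):
    """Hypothetical utility: grows with y but is capped at 1."""
    return np.minimum(y, 1.0)


y_grid = np.linspace(-4.0, 6.0, 100_001)  # mean +- 10 standard deviations
mean_value = expected_utility(lambda y: y, density, y_grid)
capped_value = expected_utility(capped_utility, density, y_grid)
print(mean_value, capped_value)
```

<p><div align="justify">Plugging the point prediction into the utility would give $u(1) = 1$, while integrating over the density gives roughly 0.80: for a nonlinear utility, the point estimate alone is not enough.</div></p>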

<hr />

<h2 id="creating-the-dataset">Creating the dataset</h2>

<p><div align="justify">Let's construct a simple illustrative problem to explore the application of non-parametric techniques in the context of CDE. Consider a data generating process of the following form:</div></p>

\[X\sim\textrm{Uniform}(0, 1),\]

\[\left(Y \,|\, X=x\right) \sim \sin\left(2\pi x\right) + \mathcal{N}\left(0, \sigma\left(x\right)^2\right),\]

<p><div align="justify">where the standard deviation is $\sigma(x) = 0.3 + 0.25 \sin(2\pi x)$.</div></p>

<p><div align="justify">Here, $X$ is one-dimensional mainly for visualization purposes; the discussion applies regardless of the dimensionality of $X$.</div></p>

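<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="k">def</span> <span class="nf">mean_function</span><span class="p">(</span><span class="n">X</span><span class="p">):</span>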
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">X</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">deviation_function</span><span class="p">(</span><span class="n">X</span><span class="p">):</span>
    <span class="k">return</span> <span class="mf">0.3</span> <span class="o">+</span> <span class="mf">0.25</span> <span class="o">*</span> <span class="n">mean_function</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">generate_data_with_normal_noise</span><span class="p">(</span>
    <span class="n">mean_generator</span><span class="p">,</span> <span class="n">deviation_generator</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">5_000</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">None</span>
<span class="p">):</span>
    <span class="k">def</span> <span class="nf">normal_noise_generator</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">deviation_generator</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">noise</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">random_state</span><span class="p">).</span><span class="n">normal</span><span class="p">(</span>
            <span class="n">loc</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="n">deviation_generator</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="p">)</span>
        <span class="k">return</span> <span class="n">noise</span>

    <span class="n">rs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">random_state</span><span class="p">).</span><span class="n">randint</span><span class="p">(</span>
        <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="o">**</span><span class="mi">32</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">int64</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">2</span>
    <span class="p">)</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">rs</span><span class="p">[</span><span class="mi">0</span><span class="p">]).</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">size</span><span class="p">)</span>
    <span class="n">y_pred</span> <span class="o">=</span> <span class="n">mean_generator</span><span class="p">(</span><span class="n">X</span><span class="o">=</span><span class="n">X</span><span class="p">)</span>
    <span class="n">noise</span> <span class="o">=</span> <span class="n">normal_noise_generator</span><span class="p">(</span>
        <span class="n">X</span><span class="o">=</span><span class="n">X</span><span class="p">,</span> <span class="n">deviation_generator</span><span class="o">=</span><span class="n">deviation_generator</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">rs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    <span class="p">)</span>
    <span class="n">y_pred_noisy</span> <span class="o">=</span> <span class="n">y_pred</span> <span class="o">+</span> <span class="n">noise</span>

    <span class="k">return</span> <span class="n">X</span><span class="p">,</span> <span class="n">y_pred_noisy</span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">generate_data_with_normal_noise</span><span class="p">(</span>
    <span class="n">mean_generator</span><span class="o">=</span><span class="n">mean_function</span><span class="p">,</span>
    <span class="n">deviation_generator</span><span class="o">=</span><span class="n">deviation_function</span><span class="p">,</span>
    <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">By construction, the covariate affects both the mean and the variance of the conditional density.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x_grid</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x_grid</span><span class="p">,</span> <span class="n">mean_function</span><span class="p">(</span><span class="n">x_grid</span><span class="p">),</span> <span class="n">color</span><span class="o">=</span><span class="s">"C0"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Mean function"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">fill_between</span><span class="p">(</span>
    <span class="n">x_grid</span><span class="p">,</span>
    <span class="n">mean_function</span><span class="p">(</span><span class="n">x_grid</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">deviation_function</span><span class="p">(</span><span class="n">x_grid</span><span class="p">),</span>
    <span class="n">mean_function</span><span class="p">(</span><span class="n">x_grid</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">deviation_function</span><span class="p">(</span><span class="n">x_grid</span><span class="p">),</span>
    <span class="n">color</span><span class="o">=</span><span class="s">"C0"</span><span class="p">,</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"95% confidence interval given x"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">"C0"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"x"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"y"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Scatter plot of generated data"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/conditional_density_estimation/output_5_0.png" /></center></div></p>
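<p><div align="justify">Since $Y \,|\, X=x$ is normal by construction, the shaded band (mean $\pm 1.96$ standard deviations) should contain about 95% of the points at every $x$, even though its width changes with $x$. The sketch below checks this numerically. <code>mean_fn</code> and <code>std_fn</code> are hypothetical stand-ins, not the <code>mean_function</code> and <code>deviation_function</code> defined earlier in this post.</div></p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical mean and standard deviation of Y given X = x.
def mean_fn(x):
    return np.sin(2 * np.pi * x)

def std_fn(x):
    return 0.1 + 0.4 * x

n = 200_000
x = rng.uniform(0, 1, size=n)
# Heteroscedastic normal noise: both the location and the scale depend on x.
y = mean_fn(x) + std_fn(x) * rng.standard_normal(n)

# Points inside mean +/- 1.96 std, overall and separately on each half of [0, 1].
inside = 1.96 * std_fn(x) >= np.abs(y - mean_fn(x))
right_half = x > 0.5
coverage = inside.mean()
coverage_left = inside[~right_half].mean()
coverage_right = inside[right_half].mean()
print(f"overall: {coverage:.3f}, left half: {coverage_left:.3f}, right half: {coverage_right:.3f}")
```

<p><div align="justify">The band describes the spread of $Y$ around its conditional mean, so the fraction of points inside it stays close to 0.95 on both halves of the domain, while the band itself is much wider where <code>std_fn</code> is larger.</div></p>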

<hr />

<h2 id="histograms">Histograms</h2>

<p><div align="justify">Density estimation may seem daunting at first, but it becomes quite intuitive once we recognize that a histogram normalized to integrate to 1 is already a density estimator. By counting the number of examples in each bin, we discretize the distribution and estimate the probability of each region. Dividing these probabilities by the bin widths then gives a "low-resolution" density estimate.</div></p>

<p><div align="justify"><center><img src="/assets/img/conditional_density_estimation/output_7_0.png" /></center></div></p>
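<p><div align="justify">Concretely, for a bin of width $h_k$ containing $n_k$ of the $n$ samples, the histogram estimate of the density in that bin is $n_k / (n h_k)$. The fraction of samples in the bin estimates the probability of the region, and dividing by the width turns it into a density. A minimal sketch on a hypothetical normal sample, compared with <code>np.histogram(..., density=True)</code>:</div></p>

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(loc=0.5, scale=0.7, size=5_000)  # hypothetical sample

edges = np.linspace(-2.0, 3.0, 26)  # 25 bins of width 0.2
counts, _ = np.histogram(y, bins=edges)
widths = np.diff(edges)

# Fraction of samples in each bin (estimated probability of the region),
# divided by the bin width to turn it into a density. Like NumPy, normalize
# by the number of samples that fall inside the edges.
density_manual = counts / (counts.sum() * widths)
density_numpy, _ = np.histogram(y, bins=edges, density=True)

# The resulting step function integrates to 1 over [-2, 3].
total_area = np.sum(density_manual * widths)
```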

<p><div align="justify">However, using all samples only yields an estimate of the marginal density of $Y$, without conditioning on $X$.</div></p>

<p><div align="justify">We can easily condition this strategy on $X=x$ by including only the points close to $X=x$ when building the histogram that represents the conditional density. The definition of "close" is flexible. For instance, we could use a strategy like <a href="https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html#sklearn.neighbors.NearestNeighbors.radius_neighbors"><code>sklearn.neighbors.NearestNeighbors.radius_neighbors</code></a>, which selects only the examples within a radius $\varepsilon$ of the point $x$, or we could select a fixed number of nearest neighbors with a method like <a href="https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html#sklearn.neighbors.NearestNeighbors.kneighbors"><code>sklearn.neighbors.NearestNeighbors.kneighbors</code></a>. The code below uses the simplest version of the first option: a window of half-width $\varepsilon = 0.05$ around each $x$.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">scipy.stats</span> <span class="kn">import</span> <span class="n">rv_histogram</span><span class="p">,</span> <span class="n">norm</span>

<span class="n">hist</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mf">1.5</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">51</span><span class="p">))</span>
<span class="n">hist_dist</span> <span class="o">=</span> <span class="n">rv_histogram</span><span class="p">(</span><span class="n">hist</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">plot_conditional_y_using_near_data</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span> <span class="n">x_value</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">0.05</span><span class="p">):</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
        <span class="n">y_grid_refined</span><span class="p">,</span>
        <span class="n">norm</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">mean_function</span><span class="p">(</span><span class="n">x_value</span><span class="p">),</span> <span class="n">scale</span><span class="o">=</span><span class="n">deviation_function</span><span class="p">(</span><span class="n">x_value</span><span class="p">)).</span><span class="n">pdf</span><span class="p">(</span>
            <span class="n">y_grid_refined</span>
        <span class="p">),</span>
        <span class="s">"--"</span><span class="p">,</span>
        <span class="n">color</span><span class="o">=</span><span class="n">c</span><span class="p">,</span>
        <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">"real $f(y | x = </span><span class="si">{</span><span class="n">x_value</span><span class="si">}</span><span class="s">)$"</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">hist</span><span class="p">(</span>
        <span class="n">y</span><span class="p">[(</span><span class="n">X</span> <span class="o">&lt;</span> <span class="n">x_value</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">X</span> <span class="o">&gt;</span> <span class="n">x_value</span> <span class="o">-</span> <span class="n">eps</span><span class="p">)],</span>
        <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span>
        <span class="n">bins</span><span class="o">=</span><span class="n">y_grid</span><span class="p">,</span>
        <span class="n">density</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
        <span class="n">color</span><span class="o">=</span><span class="n">c</span><span class="p">,</span>
        <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">"estimated $f(y | x = </span><span class="si">{</span><span class="n">x_value</span><span class="si">}</span><span class="s">)$ using near data"</span><span class="p">,</span>
    <span class="p">)</span>

<span class="n">min_y</span><span class="p">,</span> <span class="n">max_y</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">y</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
<span class="n">y_grid</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">min_y</span><span class="p">,</span> <span class="n">max_y</span><span class="p">,</span> <span class="mi">20</span><span class="p">)</span>
<span class="n">y_grid_refined</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">min_y</span><span class="p">,</span> <span class="n">max_y</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">ncols</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">bar</span><span class="p">(</span><span class="n">y_grid</span><span class="p">,</span> <span class="n">hist_dist</span><span class="p">.</span><span class="n">pdf</span><span class="p">(</span><span class="n">y_grid</span><span class="p">),</span> <span class="n">label</span><span class="o">=</span><span class="s">"estimated $f(y)$"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Density of y"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"y"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">legend</span><span class="p">()</span>

<span class="n">plot_conditional_y_using_near_data</span><span class="p">(</span><span class="n">ax</span><span class="o">=</span><span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">x_value</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"C1"</span><span class="p">)</span>
<span class="n">plot_conditional_y_using_near_data</span><span class="p">(</span><span class="n">ax</span><span class="o">=</span><span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">x_value</span><span class="o">=</span><span class="mf">0.6</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s">"C2"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Conditional density of y given X=x"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"y"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/conditional_density_estimation/output_9_0.png" /></center></div></p>
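<p><div align="justify">The window used above keeps the points whose distance to $x$ is below $\varepsilon$, which is the one-dimensional version of <code>radius_neighbors</code>. A minimal sketch of both neighbor queries from <code>sklearn.neighbors.NearestNeighbors</code> follows. The data, the radius and the number of neighbors are hypothetical choices.</div></p>

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=2_000)
y = X + (0.1 + 0.3 * X) * rng.standard_normal(X.size)  # hypothetical data

nn = NearestNeighbors().fit(X.reshape(-1, 1))
x_value = 0.2

# Option 1: every training point within a radius eps of x_value.
eps = 0.05
_, ind_radius = nn.radius_neighbors([[x_value]], radius=eps)
y_radius = y[ind_radius[0]]

# Option 2: a fixed number of nearest neighbors, whatever their distance.
_, ind_knn = nn.kneighbors([[x_value]], n_neighbors=100)
y_knn = y[ind_knn[0]]

# Either subset feeds the same normalized histogram of y given X close to x_value.
bins = np.linspace(y.min(), y.max(), 21)
hist_radius, _ = np.histogram(y_radius, bins=bins, density=True)
hist_knn, _ = np.histogram(y_knn, bins=bins, density=True)
```

<p><div align="justify">The radius version adapts the number of points to the local data density, while the $k$-nearest-neighbors version always uses the same number of points and adapts the width of the neighborhood instead.</div></p>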

<hr />

<h2 id="kernel-density-estimation">Kernel Density Estimation</h2>

<p><div align="justify">Histograms are excellent baselines, but they can struggle with more intricate distributions: choosing the number of bins is difficult, and the result is a stair-step (piecewise-constant) function that is not the most convenient to work with.</div></p>
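<p><div align="justify">To illustrate how arbitrary the bin choice can be, NumPy exposes several rules of thumb through <code>np.histogram_bin_edges</code>. The small sketch below applies some of them to the same hypothetical bimodal sample; the exact counts depend on the data.</div></p>

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical bimodal sample of 1,000 points.
y = np.concatenate([rng.normal(-1.0, 0.3, size=700), rng.normal(1.5, 0.8, size=300)])

# Number of bins chosen by some of NumPy's rules of thumb for the same sample.
rules = ["sturges", "sqrt", "rice", "scott", "fd", "doane"]
n_bins = {rule: len(np.histogram_bin_edges(y, bins=rule)) - 1 for rule in rules}
print(n_bins)
```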

<p><div align="justify">Non-parametric density estimation is frequently tackled with Kernel Density Estimation (KDE), so it is natural to use it here as well, combined with a strategy that turns the problem into a conditional one. The core idea of KDE is to place a "bump" (shaped like a Gaussian, for instance) around each observed point and then sum these bumps, normalizing the result so that it integrates to 1.</div></p>
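<p><div align="justify">In formulas, with observations $y_1, \dots, y_n$, a kernel $K$ that integrates to 1 and a bandwidth $h$, the estimate is $\hat{f}(y) = \frac{1}{nh}\sum_{i=1}^n K\left(\frac{y - y_i}{h}\right)$. The sketch below builds this sum of Gaussian bumps by hand on hypothetical data and compares it with <code>sklearn.neighbors.KernelDensity</code>, whose <code>score_samples</code> method returns the log-density.</div></p>

```python
import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.default_rng(0)
y_obs = rng.normal(0.0, 1.0, size=200)  # hypothetical observations
bandwidth = 0.3
y_grid = np.linspace(-5, 5, 1_001)

def gaussian_kernel(u):
    # Standard normal density: a kernel that integrates to 1.
    return np.exp(-0.5 * u**2) / np.sqrt(2 * np.pi)

# One bump per observation, evaluated on the grid: shape (len(y_grid), len(y_obs)).
bumps = gaussian_kernel((y_grid[:, None] - y_obs[None, :]) / bandwidth) / bandwidth
# Averaging (sum divided by n) keeps the total area equal to 1.
f_hat_manual = bumps.mean(axis=1)

kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(y_obs.reshape(-1, 1))
f_hat_sklearn = np.exp(kde.score_samples(y_grid.reshape(-1, 1)))

# Riemann-sum check that the estimate integrates to (approximately) 1.
area = f_hat_manual.sum() * (y_grid[1] - y_grid[0])
```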

<p><div align="justify">$\oint$ <em>The shape of the bumps (the kernel) and their width (the bandwidth) are hyperparameters that can be tuned by cross-validation, using a likelihood-based score that measures how likely the held-out samples are under the estimated density [<a href="#bibliography">2</a>].</em></div></p>
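<p><div align="justify">A minimal sketch of this tuning with <code>GridSearchCV</code>: <code>KernelDensity.score</code> returns the total log-likelihood of the data it receives, so the default scoring already compares bandwidths by held-out log-likelihood. The grid and the data are hypothetical. The kernel could be added to <code>param_grid</code> too, although compact kernels such as <code>"tophat"</code> can assign zero density (a log-likelihood of $-\infty$) to held-out points.</div></p>

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity

rng = np.random.default_rng(0)
y_obs = rng.normal(0.0, 1.0, size=300).reshape(-1, 1)  # hypothetical sample

# With scoring=None, GridSearchCV uses KernelDensity.score, i.e. the total
# log-likelihood of each held-out fold under the density fitted on the others.
bandwidth_grid = np.logspace(-1.5, 0.5, 20)
search = GridSearchCV(
    KernelDensity(kernel="gaussian"),
    param_grid={"bandwidth": bandwidth_grid},
    cv=5,
)
search.fit(y_obs)
best_bandwidth = search.best_params_["bandwidth"]
print(best_bandwidth, search.best_score_)
```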

<p><div align="justify"><center><img src="/assets/img/conditional_density_estimation/output_11_0.png" /></center></div></p>

<p><div align="justify">To condition our KDE, we can once again use a neighbor search. With <a href="https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html"><code>sklearn.neighbors.NearestNeighbors</code></a> and <a href="https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KernelDensity.html"><code>sklearn.neighbors.KernelDensity</code></a> (without worrying too much about their hyperparameters), we can find the 100 nearest neighbors of a specific point, say $X=0.2$, and then fit a KDE to their target values.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.neighbors</span> <span class="kn">import</span> <span class="n">NearestNeighbors</span><span class="p">,</span> <span class="n">KernelDensity</span>

<span class="n">x_value</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="n">knn</span> <span class="o">=</span> <span class="n">NearestNeighbors</span><span class="p">(</span><span class="n">n_neighbors</span><span class="o">=</span><span class="mi">100</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">_</span><span class="p">,</span> <span class="n">ind_x_value</span> <span class="o">=</span> <span class="n">knn</span><span class="p">.</span><span class="n">kneighbors</span><span class="p">([[</span><span class="n">x_value</span><span class="p">]])</span>

<span class="n">kde</span> <span class="o">=</span> <span class="n">KernelDensity</span><span class="p">(</span><span class="n">kernel</span><span class="o">=</span><span class="s">"gaussian"</span><span class="p">,</span> <span class="n">bandwidth</span><span class="o">=</span><span class="s">"scott"</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span>
    <span class="n">y</span><span class="p">[</span><span class="n">ind_x_value</span><span class="p">].</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">y_grid_refined</span><span class="p">,</span>
    <span class="n">norm</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">mean_function</span><span class="p">(</span><span class="n">x_value</span><span class="p">),</span> <span class="n">scale</span><span class="o">=</span><span class="n">deviation_function</span><span class="p">(</span><span class="n">x_value</span><span class="p">)).</span><span class="n">pdf</span><span class="p">(</span>
        <span class="n">y_grid_refined</span>
    <span class="p">),</span>
    <span class="s">"--"</span><span class="p">,</span>
    <span class="n">color</span><span class="o">=</span><span class="s">"C0"</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">"real $f(y | x = </span><span class="si">{</span><span class="n">x_value</span><span class="si">}</span><span class="s">)$"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">hist</span><span class="p">(</span>
    <span class="n">y</span><span class="p">[</span><span class="n">ind_x_value</span><span class="p">].</span><span class="n">ravel</span><span class="p">(),</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span>
    <span class="n">bins</span><span class="o">=</span><span class="n">y_grid</span><span class="p">,</span>
    <span class="n">density</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
    <span class="n">color</span><span class="o">=</span><span class="s">"C0"</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">"estimated $f(y | x = </span><span class="si">{</span><span class="n">x_value</span><span class="si">}</span><span class="s">)$ using a histogram of nearest neighbors"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">y_grid_refined</span><span class="p">,</span>
    <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">kde</span><span class="p">.</span><span class="n">score_samples</span><span class="p">(</span><span class="n">y_grid_refined</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))),</span>
    <span class="n">color</span><span class="o">=</span><span class="s">"C0"</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">"estimated $f(y | x = </span><span class="si">{</span><span class="n">x_value</span><span class="si">}</span><span class="s">)$ using a kde with nearest neighbors"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Conditional density of y given X=x"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"y"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/conditional_density_estimation/output_14_0.png" /></center></div></p>

<p><div align="justify">Notice that the KDE provides a much smoother estimate than the histogram.</div></p>

<p><div align="justify">We can encapsulate this logic in a class that follows the <a href="https://scikit-learn.org/stable/developers/develop.html">scikit-learn API</a>, so that the <code>.predict</code> method applies it to each requested sample: it first searches for the neighbors and then fits a KDE to their targets, returning one density estimate per example.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.base</span> <span class="kn">import</span> <span class="n">BaseEstimator</span><span class="p">,</span> <span class="n">clone</span>

<span class="k">class</span> <span class="nc">ConditionalNearestNeighborsKDE</span><span class="p">(</span><span class="n">BaseEstimator</span><span class="p">):</span>
    <span class="s">"""Conditional Kernel Density Estimation using nearest neighbors.

    This class implements a Conditional Kernel Density Estimation by applying
    the Kernel Density Estimation algorithm after a nearest neighbors search.

    It allows the use of user-specified nearest neighbor and kernel density
    estimators or, if not provided, defaults will be used.

    Parameters
    ----------
    nn_estimator : NearestNeighbors instance, default=None
        A pre-configured instance of a `~sklearn.neighbors.NearestNeighbors` class
        to use for finding nearest neighbors. If not specified, a
        `~sklearn.neighbors.NearestNeighbors` instance with `n_neighbors=100`
        will be used.

    kde_estimator : KernelDensity instance, default=None
        A pre-configured instance of a `~sklearn.neighbors.KernelDensity` class
        to use for estimating the kernel density. If not specified, a
        `~sklearn.neighbors.KernelDensity` instance with `bandwidth="scott"`
        will be used.
    """</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">nn_estimator</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">kde_estimator</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator</span> <span class="o">=</span> <span class="n">nn_estimator</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">kde_estimator</span> <span class="o">=</span> <span class="n">kde_estimator</span>

    <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator_</span> <span class="o">=</span> <span class="n">NearestNeighbors</span><span class="p">(</span><span class="n">n_neighbors</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator_</span> <span class="o">=</span> <span class="n">clone</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator_</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
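        <span class="bp">self</span><span class="p">.</span><span class="n">y_train_</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>  <span class="c1"># array, so that y_train_[ind] also works for lists and Series</span>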
        <span class="k">return</span> <span class="bp">self</span>

    <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="s">"""Estimate the conditional density of the target for new samples.

        One density estimate of the target is returned for each sample in X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Vector to be estimated, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        Returns
        -------
        kernel_density_list : list of len n_samples of KernelDensity instances
            Estimated conditional densities, one fitted
            `~sklearn.neighbors.KernelDensity` instance per sample.
        """</span>
        <span class="n">_</span><span class="p">,</span> <span class="n">ind_X</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator_</span><span class="p">.</span><span class="n">kneighbors</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">kde_estimator</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">kernel_density_list</span> <span class="o">=</span> <span class="p">[</span>
                <span class="n">KernelDensity</span><span class="p">(</span><span class="n">bandwidth</span><span class="o">=</span><span class="s">"scott"</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">y_train_</span><span class="p">[</span><span class="n">ind</span><span class="p">].</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
                <span class="k">for</span> <span class="n">ind</span> <span class="ow">in</span> <span class="n">ind_X</span>
            <span class="p">]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">kernel_density_list</span> <span class="o">=</span> <span class="p">[</span>
                <span class="n">clone</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">kde_estimator</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">y_train_</span><span class="p">[</span><span class="n">ind</span><span class="p">].</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
                <span class="k">for</span> <span class="n">ind</span> <span class="ow">in</span> <span class="n">ind_X</span>
            <span class="p">]</span>
        <span class="k">return</span> <span class="n">kernel_density_list</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>

<span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span>
    <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.33</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span>
<span class="p">)</span>

<span class="n">ckde</span> <span class="o">=</span> <span class="n">ConditionalNearestNeighborsKDE</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">ckde_preds</span> <span class="o">=</span> <span class="n">ckde</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</code></pre></div></div>
<hr />

<h2 id="evaluation-metrics-for-conditional-density-estimation-methods">Evaluation metrics for conditional density estimation methods</h2>

<p><div align="justify">Traditional regression metrics compare a point prediction with the observed target, so they cannot be applied directly when each prediction is an entire density; we need an approach specific to this problem. This discussion is a bit more involved, but it is critical for evaluating our estimators.</div></p>

<p><div align="justify">$\oint$ <em>Some metrics from conformal prediction could be used here, such as "how often the observed target falls within a prediction interval", if you build intervals from the estimated conditional densities. However, metrics designed for density estimation itself are better suited to the problem.</em></div></p>

<p><div align="justify">Let's denote the true conditional probability density of the problem as $f(y \,|\, X=x)$ and the estimated density as $\hat{f}(y \,|\, X=x)$. We want to gauge how close these two functions are, even though we don't know $f(y \,|\, X=x)$. A smart way to handle this is to compute the integrated squared difference between them, averaged over $X$ [<a href="#bibliography">3</a>]:</div></p>

\[L(f, \hat{f}) = \mathbb{E}_X\left( \int \left( \hat{f}(y \,|\, X) - f(y \,|\, X) \right)^2 dy \right) = \int \int \left( \hat{f}(y \,|\, X=x) - f(y \,|\, X=x) \right)^2  dy \, f(x) \, dx.\]

<p><div align="justify">$\oint$ <em>This metric differs somewhat from the mean squared error used as empirical risk for point estimates $h(x)$. When we calculate $\frac{1}{n} \sum_{i=1}^n \left( h(x_i) - y_i \right)^2$, we are effectively estimating</em></div></p>

\[\mathbb{E}_{(X, Y)}\left( (h(X) - Y)^2 \right) = \int \int \left( h(x) - y \right)^2 f(x,y) \, dx \, dy.\]

<p><div align="justify"><em>In the metric $L$, we average only with respect to $X$: for a fixed $X=x$, we want $\hat{f}(y \,|\, X=x)$ to approximate $f(y \,|\, X=x)$ well for every $y \in \mathbb{R}$, with all values of $y$ weighted equally rather than according to how likely they are.</em></div></p>

<p><div align="justify">Upon expanding $L$, we obtain</div></p>

\[L(f, \hat{f}) = \int \int \left( \hat{f}(y \,|\, X=x) \right)^2 f(x) \, dy \, dx - 2 \int \int \hat{f}(y \,|\, X=x) \, f(x, y) \, dy \, dx + C,\]

<p><div align="justify">where $f(x,y) = f(y \,|\, x) f(x)$ and $C = \int \int \left( f(y \,|\, x) \right)^2 f(x) \, dy \, dx$. Since $C$ does not depend on $\hat{f}$, it can be disregarded when comparing models.</div></p>
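<p><div align="justify">Spelling out the cross term shows why it can be estimated from data even though $f$ is unknown: since $f(y \,|\, X=x)\, f(x) = f(x,y)$, it is an expectation over the joint distribution of $(X, Y)$,</div></p>

\[\int \int \hat{f}(y \,|\, X=x)\, f(y \,|\, X=x)\, f(x) \, dy \, dx = \int \int \hat{f}(y \,|\, X=x)\, f(x,y) \, dy \, dx = \mathbb{E}_{(X, Y)}\left( \hat{f}(Y \,|\, X) \right),\]

<p><div align="justify">which an average over observed pairs $(x_i, y_i)$ can approximate.</div></p>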

<p><div align="justify">The first term can be written as</div></p>

\[\int \left( \int \left( \hat{f}(y \,|\, X=x) \right)^2\, dy \right) f(x)  \, dx,\]

<p><div align="justify">and the inner integral over $y$ can be computed with a numerical integration method, while the outer integral over $x$ can be estimated by an empirical average over a validation sample $S=(x_i, y_i)_{i=1}^n$. Specifically, we have</div></p>

\[\frac{1}{n} \sum_{i=1}^n \left( \int  \left( \hat{f}(y \,|\, X=x_i) \right)^2 \, dy \right).\]

<p><div align="justify">The second term can be directly estimated as the empirical average</div></p>

\[\frac{-2}{n} \sum_{i=1}^n \hat{f}(y_i \,|\, X=x_i),\]

<p><div align="justify">also using $S$.</div></p>

<p><div align="justify">Our estimates enable us to calculate a model comparison metric given by</div></p>

\[L(f, \hat{f}) \approx \hat{L}(f, \hat{f}, S) = \frac{1}{n} \sum_{i=1}^n \left( \int  \left( \hat{f}(y \,|\, X=x_i) \right)^2 \, dy \right) - \frac{2}{n} \sum_{i=1}^n \hat{f}(y_i \,|\, X=x_i),\]

<p><div align="justify">where a good model should yield as small a value as possible [<a href="#bibliography">3</a>].</div></p>

<p><div align="justify">$\oint$ <em>It is instructive to interpret this final expression informally. To minimize $\hat{L}(f, \hat{f}, S)$, we want the integrals of the squared conditional densities to be small, while the densities evaluated at the observed samples are large (making the second term very negative). In other words, the first term penalizes estimates that concentrate their mass in tall, sharp peaks, while the second term rewards estimates that assign high density to the points we actually observed.</em></div></p>
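<p><div align="justify">To see this trade-off numerically, here is a minimal NumPy sketch of $\hat{L}(f, \hat{f}, S)$ on toy data we made up for illustration: $X \sim U(0, 1)$ and $Y \,|\, X=x \sim N(x, 1)$. The estimators are hypothetical Gaussians $\hat{f}(y \,|\, X=x) = N(y; x, \sigma^2)$, and the helpers <code>gaussian_pdf</code>, <code>trapezoid</code> and <code>empirical_loss</code> are illustrative names, not part of any library.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def gaussian_pdf(y, mean, std):
    # Normal density N(y; mean, std**2), evaluated elementwise with broadcasting.
    return np.exp(-0.5 * ((y - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))


def trapezoid(values, grid):
    # Trapezoidal rule along the last axis of `values`.
    return np.sum((values[..., 1:] + values[..., :-1]) / 2 * np.diff(grid), axis=-1)


def empirical_loss(x, y, std, y_grid):
    """L-hat for the hypothetical estimator f_hat(y | X=x) = N(y; x, std**2)."""
    # First term: average over samples of the integral of f_hat(y | x_i)**2 dy.
    dens_on_grid = gaussian_pdf(y_grid[None, :], x[:, None], std)
    first_term = trapezoid(dens_on_grid ** 2, y_grid).mean()
    # Second term: -2 times the average of f_hat(y_i | x_i).
    second_term = -2 * gaussian_pdf(y, x, std).mean()
    return first_term + second_term


rng = np.random.default_rng(0)
n = 5000
x = rng.uniform(0, 1, size=n)
y = x + rng.normal(size=n)  # toy data: Y | X=x ~ N(x, 1)
y_grid = np.linspace(-8, 9, 801)

# std=1.0 matches the true conditional density; 0.5 is too peaked, 2.0 too flat.
losses = {std: empirical_loss(x, y, std, y_grid) for std in (0.5, 1.0, 2.0)}
print(losses)
</code></pre></div></div>

<p><div align="justify">For this toy problem the population values are $\frac{1}{2\sqrt{\pi}\,\sigma} - \frac{2}{\sqrt{2\pi(1+\sigma^2)}}$, roughly $-0.149$, $-0.282$ and $-0.216$ for $\sigma = 0.5, 1, 2$, so the sample values should land close to these, with the correct model lowest. Here $C = 1/(2\sqrt{\pi}) \approx 0.282$, so adding it back brings the correct model's loss to approximately zero.</div></p>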

<p><div align="justify">We can implement this so that it accepts pre-computed density estimates and performs the necessary operations (both integration and summation). For the integral, we explicitly ask for a <code>y_grid</code>, on which the integral is approximated with the trapezoidal rule through <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.auc.html"><code>sklearn.metrics.auc</code></a>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">auc</span><span class="p">,</span> <span class="n">make_scorer</span>
<span class="kn">from</span> <span class="nn">joblib</span> <span class="kn">import</span> <span class="n">Parallel</span><span class="p">,</span> <span class="n">delayed</span>

<span class="k">def</span> <span class="nf">squared_loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">cde_preds</span><span class="p">,</span> <span class="n">y_grid</span><span class="p">,</span> <span class="n">n_jobs</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
    <span class="s">"""
    Average squared loss between the true conditional density and predicted one.

    This method can be used to assess the quality of the conditional probability
    density function fit.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        The true values of y for each sample.

    cde_preds : list of len n_samples of KernelDensity instances
        The predicted conditional densities. Each instance should be a fitted
        KernelDensity instance.

    y_grid : array-like of shape (n_grid_points,)
        Increasing grid of y values on which the squared density is integrated
        with the trapezoidal rule (via `auc`). It should cover the region where
        the predicted densities have mass.

    n_jobs : int, default=-1
        The number of jobs to run in parallel. -1 means using all processors.

    Returns
    -------
    average_squared_loss : float
        Estimate of the average squared loss between the true and predicted
        conditional probability density functions, up to the additive constant
        C, which does not depend on the model.
    """</span>

    <span class="k">def</span> <span class="nf">_compute_individual_loss</span><span class="p">(</span><span class="n">y_</span><span class="p">,</span> <span class="n">cde_pred</span><span class="p">):</span>
        <span class="c1"># score_samples and score return log-densities (score sums them over
</span>        <span class="c1"># the samples), so we exponentiate to get densities.
</span>        <span class="n">squared_auc</span> <span class="o">=</span> <span class="n">auc</span><span class="p">(</span>
            <span class="n">y_grid</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">cde_pred</span><span class="p">.</span><span class="n">score_samples</span><span class="p">(</span><span class="n">y_grid</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span> <span class="o">**</span> <span class="mi">2</span>
        <span class="p">)</span>
        <span class="n">expected_value</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">cde_pred</span><span class="p">.</span><span class="n">score</span><span class="p">([[</span><span class="n">y_</span><span class="p">]]))</span>
        <span class="k">return</span> <span class="n">squared_auc</span> <span class="o">-</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">expected_value</span>

    <span class="n">individual_squared_loss</span> <span class="o">=</span> <span class="n">Parallel</span><span class="p">(</span><span class="n">n_jobs</span><span class="o">=</span><span class="n">n_jobs</span><span class="p">)(</span>
        <span class="n">delayed</span><span class="p">(</span><span class="n">_compute_individual_loss</span><span class="p">)(</span><span class="n">y_</span><span class="p">,</span> <span class="n">cde_pred</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">y_</span><span class="p">,</span> <span class="n">cde_pred</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">cde_preds</span><span class="p">)</span>
    <span class="p">)</span>

    <span class="n">average_squared_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">individual_squared_loss</span><span class="p">)</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">y_true</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">average_squared_loss</span>
</code></pre></div></div>
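<p><div align="justify">The integral inside <code>squared_loss</code> relies on <code>auc</code> applying the trapezoidal rule over <code>y_grid</code>, so the grid has to cover the region where the densities have mass. A quick sanity check on a case with a closed form, $\int \phi(t)^2 \, dt = 1/(2\sqrt{\pi})$ for the standard normal density $\phi$, illustrates what happens with a wide and a too-narrow grid (the helper <code>standard_normal_pdf</code> is ours, for illustration only).</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.metrics import auc


def standard_normal_pdf(t):
    return np.exp(-(t ** 2) / 2) / np.sqrt(2 * np.pi)


exact = 1 / (2 * np.sqrt(np.pi))  # integral of phi(t)**2 over the real line

# auc(x, y) integrates the piecewise-linear curve through (x, y): the trapezoidal rule.
wide_grid = np.linspace(-5, 5, 1000)
narrow_grid = np.linspace(-1, 1, 1000)
wide = auc(wide_grid, standard_normal_pdf(wide_grid) ** 2)
narrow = auc(narrow_grid, standard_normal_pdf(narrow_grid) ** 2)

print(exact, wide, narrow)  # the narrow grid misses the tails and underestimates
</code></pre></div></div>

<p><div align="justify">The wide grid recovers $1/(2\sqrt{\pi}) \approx 0.282$ essentially exactly, while $[-1, 1]$ only captures about $0.238$. A grid that is too narrow biases the first term of $\hat{L}$ downwards, which could favor overly spread-out estimates.</div></p>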

<p><div align="justify">Applying this to the test set gives us a way to quantify the performance of our CDE model.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">squared_loss</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">ckde_preds</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1000</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>-0.837595643080642
</code></pre></div></div>

<p><div align="justify">For comparison, we can contrast it with an estimate of the marginal density of $Y$, which ignores $X$ altogether: a single KDE trained on all the training targets and used as the prediction for every test sample.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">squared_loss</span><span class="p">(</span>
    <span class="n">y_test</span><span class="p">,</span>
    <span class="nb">len</span><span class="p">(</span><span class="n">y_test</span><span class="p">)</span> <span class="o">*</span> <span class="p">[</span><span class="n">KernelDensity</span><span class="p">(</span><span class="n">bandwidth</span><span class="o">=</span><span class="s">"scott"</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">y_train</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))],</span>
    <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1000</span><span class="p">),</span>
<span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>-0.38725117712967094
</code></pre></div></div>

<p><div align="justify">Since the conditional estimator obtains the lower value (about $-0.84$ versus $-0.39$), we conclude that it provides the better density estimate, as anticipated.</div></p>

<p><div align="justify">$\oint$ <em>While this metric is useful for comparing models, it might be difficult to interpret from a business perspective. In this case, it might be helpful to convert your distribution forecast into a point forecast to calculate a more traditional metric, such as <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html"><code>sklearn.metrics.mean_absolute_error</code></a> (or even conformal prediction metrics), to provide a more digestible interpretation.</em></div></p>
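<p><div align="justify">One way to do this conversion, sketched under the assumption that each prediction is a fitted <code>KernelDensity</code> like the ones <code>squared_loss</code> consumes, is to evaluate each density on a grid and take its mean or median. The helper <code>point_forecasts</code> below is hypothetical, and the data is a toy sample drawn from $N(2, 1)$.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.metrics import mean_absolute_error
from sklearn.neighbors import KernelDensity


def point_forecasts(cde_preds, y_grid):
    """Hypothetical helper: mean and median of each fitted KernelDensity on y_grid."""
    means, medians = [], []
    for kde in cde_preds:
        # score_samples returns log-densities, so exponentiate them.
        dens = np.exp(kde.score_samples(y_grid.reshape(-1, 1)))
        # Cumulative trapezoidal integral of the density along the grid.
        cdf = np.concatenate([[0.0], np.cumsum((dens[1:] + dens[:-1]) / 2 * np.diff(y_grid))])
        mass = cdf[-1]  # close to 1 if the grid covers the support
        weighted = y_grid * dens
        means.append(np.sum((weighted[1:] + weighted[:-1]) / 2 * np.diff(y_grid)) / mass)
        medians.append(np.interp(0.5, cdf / mass, y_grid))
    return np.array(means), np.array(medians)


rng = np.random.default_rng(0)
y_train = rng.normal(loc=2.0, scale=1.0, size=5000)
kde = KernelDensity(bandwidth=0.3).fit(y_train.reshape(-1, 1))
y_grid = np.linspace(-5, 9, 2000)

# The same marginal KDE used as the prediction for two test points.
means, medians = point_forecasts([kde, kde], y_grid)
y_obs = np.array([1.5, 2.5])
print(means, medians, mean_absolute_error(y_obs, medians))
</code></pre></div></div>

<p><div align="justify">The median is the natural choice when reporting mean absolute error, since it is the point forecast that minimizes expected absolute error under the predicted distribution.</div></p>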

<p><div align="justify">With a way to compare models in place, it is natural to tune hyperparameters with a tool like <a href="https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html"><code>sklearn.model_selection.GridSearchCV</code></a>. Since we designed <code>ConditionalNearestNeighborsKDE</code> to follow the <a href="https://scikit-learn.org/stable/developers/develop.html">scikit-learn standard</a>, and the metric to accept the output of a <code>.predict</code> method, we can use it directly to tune the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html"><code>sklearn.neighbors.NearestNeighbors</code></a> estimator inside our model.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">GridSearchCV</span>

<span class="n">squared_loss_score</span> <span class="o">=</span> <span class="n">make_scorer</span><span class="p">(</span>
    <span class="n">partial</span><span class="p">(</span><span class="n">squared_loss</span><span class="p">,</span> <span class="n">y_grid</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)),</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="bp">False</span>
<span class="p">)</span>
<span class="n">param_grid</span> <span class="o">=</span> <span class="p">{</span>
    <span class="s">"nn_estimator"</span><span class="p">:</span> <span class="p">[</span>
        <span class="n">NearestNeighbors</span><span class="p">(</span><span class="n">n_neighbors</span><span class="o">=</span><span class="n">n_neighbors</span><span class="p">)</span> <span class="k">for</span> <span class="n">n_neighbors</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">100</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mi">1000</span><span class="p">]</span>
    <span class="p">],</span>
<span class="p">}</span>
<span class="n">gs</span> <span class="o">=</span> <span class="n">GridSearchCV</span><span class="p">(</span>
    <span class="n">ConditionalNearestNeighborsKDE</span><span class="p">(),</span> <span class="n">param_grid</span><span class="o">=</span><span class="n">param_grid</span><span class="p">,</span> <span class="n">scoring</span><span class="o">=</span><span class="n">squared_loss_score</span>
<span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">y_train</span><span class="p">)</span>

<span class="p">(</span>
    <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">gs</span><span class="p">.</span><span class="n">cv_results_</span><span class="p">)</span>
    <span class="p">.</span><span class="nb">filter</span><span class="p">(</span>
        <span class="p">[</span><span class="s">"param_nn_estimator"</span><span class="p">,</span> <span class="s">"mean_score_time"</span><span class="p">,</span> <span class="s">"mean_test_score"</span><span class="p">,</span> <span class="s">"std_test_score"</span><span class="p">]</span>
    <span class="p">)</span>
    <span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="n">by</span><span class="o">=</span><span class="s">"mean_test_score"</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
    <span class="p">.</span><span class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="p">)</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>param_nn_estimator</th>
      <th>mean_score_time</th>
      <th>mean_test_score</th>
      <th>std_test_score</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>NearestNeighbors(n_neighbors=500)</td>
      <td>3.838123</td>
      <td>0.890078</td>
      <td>0.021875</td>
    </tr>
    <tr>
      <th>1</th>
      <td>NearestNeighbors(n_neighbors=100)</td>
      <td>1.243606</td>
      <td>0.859500</td>
      <td>0.016847</td>
    </tr>
    <tr>
      <th>2</th>
      <td>NearestNeighbors(n_neighbors=1000)</td>
      <td>6.416354</td>
      <td>0.711722</td>
      <td>0.018199</td>
    </tr>
  </tbody>
</table>
</div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">squared_loss</span><span class="p">(</span>
    <span class="n">y_test</span><span class="p">,</span> <span class="n">gs</span><span class="p">.</span><span class="n">best_estimator_</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)),</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
<span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>-0.9058110146877884
</code></pre></div></div>

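<p><div align="justify">Note that the table reports negated losses, since <code>make_scorer</code> was called with <code>greater_is_better=False</code>, so the best configuration has the highest <code>mean_test_score</code>. The selected model (500 neighbors) reaches a test loss of about $-0.906$, better than the $-0.838$ of our first model. We could also tune aspects of the kernel estimation itself, such as the bandwidth, which might further improve the result.</div></p>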
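<p><div align="justify">The positive scores in the table come from the sign convention of <code>make_scorer</code>: with <code>greater_is_better=False</code>, the resulting scorer returns the negated value of the wrapped loss, so that scikit-learn can always maximize. A minimal check with a standard loss and a toy <code>DummyRegressor</code>:</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.dummy import DummyRegressor
from sklearn.metrics import make_scorer, mean_squared_error

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = rng.normal(size=100)
model = DummyRegressor().fit(X, y)  # always predicts the training mean

loss = mean_squared_error(y, model.predict(X))
neg_mse_scorer = make_scorer(mean_squared_error, greater_is_better=False)
score = neg_mse_scorer(model, X, y)  # scorers are called as scorer(estimator, X, y)

print(loss, score)  # score is the negated loss
</code></pre></div></div>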

<p><div align="justify">We started with the <code>ConditionalNearestNeighborsKDE</code> structure because it is intuitive. However, in higher dimensions or with a lot of data, the neighbor search runs into several issues. First, it is computationally costly, since it requires many distance computations. Second, distances are at the mercy of the varying scales of the variables, which may even include categorical variables. Third, $X$ may contain many uninformative variables, and we then suffer from the <a href="https://en.wikipedia.org/wiki/Curse_of_dimensionality">curse of dimensionality</a>: our neighbors become increasingly distant and less representative. In a real-world problem, you might have hundreds of covariates to condition on and millions of examples, which can make this strategy unsuitable.</div></p>
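<p><div align="justify">A small experiment on uniform toy data illustrates the curse of dimensionality for neighbor search: as the number of dimensions grows, the nearest and the farthest points from a query end up at similar distances. The helper <code>relative_contrast</code> is a hypothetical name for $(d_{\max} - d_{\min}) / d_{\min}$ computed with Euclidean distances.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def relative_contrast(n_dims, n_points=500, seed=0):
    """(farthest - nearest) / nearest Euclidean distance from a random query point."""
    rng = np.random.default_rng(seed)
    points = rng.uniform(size=(n_points, n_dims))
    query = rng.uniform(size=n_dims)
    dists = np.linalg.norm(points - query, axis=1)
    return (dists.max() - dists.min()) / dists.min()


# A small contrast means the "nearest" neighbors are barely nearer than anyone else.
for n_dims in (1, 10, 100, 1000):
    print(n_dims, relative_contrast(n_dims))
</code></pre></div></div>

<p><div align="justify">Adding many uninformative variables to $X$ pushes the problem toward the high-dimensional regime, even when the target only depends on a few of them.</div></p>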

<hr />

<h2 id="leafneighbors">LeafNeighbors</h2>

<p><div align="justify">One way to bypass the complications of neighbor search in high dimensions (irrelevant variables, and variables with different scales and types) is to use a distance metric that is robust to these issues.</div></p>

<p><div align="justify">Tree models handle these issues naturally: they learn which features are important through the process of choosing the best splits, and they are insensitive to the scale of the variables, since a split depends only on the ordering of the values during training.</div></p>
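<p><div align="justify">As a quick illustration of the scale argument, the sketch below fits the same <code>DecisionTreeRegressor</code> on toy data and on a rescaled copy of it (a positive affine change of units, which preserves the ordering of every feature). We use integer-valued features only so that the rescaled copy is represented exactly in floating point.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
# Integer-valued features so the rescaled copy below is represented exactly.
X = rng.integers(0, 1000, size=(300, 2)).astype(float)
y = np.sin(X[:, 0] / 100) + rng.normal(scale=0.1, size=300)

# Change of units: a positive affine map preserves the ordering of each feature.
X_rescaled = 1000 * X + 5

tree = DecisionTreeRegressor(max_depth=5, random_state=0).fit(X, y)
tree_rescaled = DecisionTreeRegressor(max_depth=5, random_state=0).fit(X_rescaled, y)

# Compare the split features chosen by both trees and their predictions.
print(np.array_equal(tree.tree_.feature, tree_rescaled.tree_.feature))
print(np.allclose(tree.predict(X), tree_rescaled.predict(X_rescaled)))
</code></pre></div></div>

<p><div align="justify">Only the threshold values change between the two trees; the partition of the training samples, and hence the predictions, stay the same.</div></p>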

<p><div align="justify">When training an ensemble of randomized trees, the splits vary across the feature space from tree to tree, so we can use co-occurrence in the same leaves as a measure of similarity between examples [<a href="#bibliography">4</a>].</div></p>
<p><div align="justify">Hence, if we train an ensemble of regression trees such as <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html"><code>sklearn.ensemble.RandomForestRegressor</code></a> or <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesRegressor.html"><code>sklearn.ensemble.ExtraTreesRegressor</code></a> to predict $Y$ from $X$, we build trees whose splits are on the variables relevant for predicting $Y$. At the same time, differences in scale become irrelevant: we treat instances that fall in the same leaf as similar, and measure similarity by counting how often two instances share a leaf across the trees of the ensemble [<a href="#bibliography">5</a>].</div></p>
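<p><div align="justify">Concretely, <code>apply</code> on a fitted forest returns one leaf index per tree for each sample, and the Hamming distance between two such vectors is the fraction of trees in which the samples land in different leaves, that is, one minus their leaf co-occurrence rate. The sketch below checks this on toy data, using SciPy's <code>cdist</code> as a stand-in for the Hamming metric that <code>NearestNeighbors</code> applies internally.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from scipy.spatial.distance import cdist
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X[:, 0] + 0.1 * rng.normal(size=200)  # only the first feature matters

forest = RandomForestRegressor(n_estimators=50, max_depth=5, random_state=0).fit(X, y)
leaves = forest.apply(X)  # shape (n_samples, n_estimators): one leaf index per tree

# Hamming distance on leaf indices = fraction of trees where two samples
# fall in different leaves, i.e. 1 - (leaf co-occurrence rate).
dist = cdist(leaves, leaves, metric="hamming")
manual = np.mean(leaves[0] != leaves[1])
print(dist[0, 1], manual)
</code></pre></div></div>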

<p><div align="justify">We can design a neighbor search class following this rationale, in accordance with the <a href="https://scikit-learn.org/stable/developers/develop.html">scikit-learn standards</a>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.neighbors._base</span> <span class="kn">import</span> <span class="n">NeighborsBase</span>
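<span class="kn">from</span> <span class="nn">sklearn.ensemble</span> <span class="kn">import</span> <span class="n">RandomForestRegressor</span>
<span class="kn">from</span> <span class="nn">sklearn.base</span> <span class="kn">import</span> <span class="n">clone</span>
<span class="kn">from</span> <span class="nn">sklearn.neighbors</span> <span class="kn">import</span> <span class="n">NearestNeighbors</span>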

<span class="k">class</span> <span class="nc">LeafNeighbors</span><span class="p">(</span><span class="n">NeighborsBase</span><span class="p">):</span>
    <span class="s">"""Neighbor search using leaf co-occurrence in a tree ensemble as a
    similarity measure.

    This class implements a supervised neighbor search that uses the leaves of
    a fitted tree ensemble to define distances: the Hamming distance between
    two examples is the fraction of trees in which they fall in different
    leaves. Examples that share leaves in many trees are close in the
    variables relevant to the target.

    Parameters
    ----------
    tree_ensemble_estimator : ForestRegressor instance, default=None
        The ensemble tree estimator to use. If None, a
        `~sklearn.ensemble.RandomForestRegressor` with `max_depth=10` will be
        used.

    n_neighbors : int, default=5
        Number of neighbors to use in the neighbor-based learning method.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of the default RandomForestRegressor; ignored
        when `tree_ensemble_estimator` is given. Pass an int for reproducible
        output across multiple function calls.
    """</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tree_ensemble_estimator</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">n_neighbors</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">tree_ensemble_estimator</span> <span class="o">=</span> <span class="n">tree_ensemble_estimator</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_neighbors</span> <span class="o">=</span> <span class="n">n_neighbors</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">random_state</span> <span class="o">=</span> <span class="n">random_state</span>

    <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="c1"># Fitted objects are stored in trailing-underscore attributes so that the
</span>        <span class="c1"># constructor parameters stay untouched, as scikit-learn expects.
</span>        <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">tree_ensemble_estimator</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">tree_ensemble_estimator_</span> <span class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">(</span>
                <span class="n">max_depth</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span>
            <span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">tree_ensemble_estimator_</span> <span class="o">=</span> <span class="n">clone</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">tree_ensemble_estimator</span><span class="p">)</span>

        <span class="c1"># Hamming distance between leaf-index vectors: fraction of trees in
</span>        <span class="c1"># which two samples land in different leaves.
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator_</span> <span class="o">=</span> <span class="n">NearestNeighbors</span><span class="p">(</span>
            <span class="n">n_neighbors</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">n_neighbors</span><span class="p">,</span> <span class="n">metric</span><span class="o">=</span><span class="s">"hamming"</span>
        <span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">tree_ensemble_estimator_</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
        <span class="n">leafs_X</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">tree_ensemble_estimator_</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator_</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">leafs_X</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span>

    <span class="k">def</span> <span class="nf">kneighbors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="n">leafs_X</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">tree_ensemble_estimator_</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator_</span><span class="p">.</span><span class="n">kneighbors</span><span class="p">(</span><span class="n">leafs_X</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">radius_neighbors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="n">leafs_X</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">tree_ensemble_estimator_</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">nn_estimator_</span><span class="p">.</span><span class="n">radius_neighbors</span><span class="p">(</span><span class="n">leafs_X</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">We can then use it in <code>ConditionalNearestNeighborsKDE</code> by passing the custom search method as the <code>nn_estimator</code> parameter.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">crfkde</span> <span class="o">=</span> <span class="n">ConditionalNearestNeighborsKDE</span><span class="p">(</span>
    <span class="n">nn_estimator</span><span class="o">=</span><span class="n">LeafNeighbors</span><span class="p">(</span><span class="n">n_neighbors</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">crfkde_preds</span> <span class="o">=</span> <span class="n">crfkde</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>

<span class="n">squared_loss</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">crfkde_preds</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1000</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>-0.8328820889563571
</code></pre></div></div>

<p><div align="justify">In this scenario, the loss ($-0.833$) is close to the one obtained with the Euclidean nearest neighbors ($-0.838$). This is expected because $X$ is one-dimensional: the neighbors found along the line closely match those returned by the conventional Euclidean neighbor search.</div></p>

<hr />

<h2 id="flexcode">FlexCode</h2>

<p><div align="justify">FlexCode takes a fundamentally different approach to the CDE problem by employing arguments from linear algebra to estimate the conditional probability density function using a function basis.</div></p>

<p><div align="justify">The space of <a href="https://mathworld.wolfram.com/L2-Space.html">square integrable functions</a> ($L^2(\mathbb{R})$) is a vector space equipped with the inner product $\left\langle g, h\right\rangle = \int_{\mathbb{R}} g(t)\, h(t) \, dt$. Similar to finite-dimensional vector spaces, it has a (in this case, countably infinite) basis $\{ \phi_i \in L^2(\mathbb{R}) : i \in \mathbb{N}\}$, such that any function $g \in L^2(\mathbb{R})$ can be written as a linear combination of the basis elements, $g = \sum_{i=1}^\infty \beta_i \phi_i$, where the series converges in the $L^2$ sense. Furthermore, the basis can be chosen to be orthonormal, that is, $\left\langle \phi_i, \phi_j\right\rangle = \delta_{i,j}$, where $\delta_{i,j}$ equals $1$ if $i = j$ and $0$ otherwise [<a href="#bibliography">3</a>]. If this concept is unfamiliar, <a href="https://en.wikipedia.org/wiki/Fourier_series">Fourier series</a> are the classic example.</div></p>
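<p><div align="justify">For a concrete orthonormal basis, consider the cosine system on $[0, 1]$, $\phi_1(t) = 1$ and $\phi_j(t) = \sqrt{2}\cos((j-1)\pi t)$. We use it only as an illustration; we are assuming it resembles the kind of basis a CDE method would use, not that it matches any library's definition exactly. The sketch approximates the Gram matrix $\left\langle \phi_i, \phi_j\right\rangle$ with the trapezoidal rule and should recover the identity matrix.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def cosine_basis(t, n_basis):
    # phi_1(t) = 1 and phi_j(t) = sqrt(2) cos((j - 1) pi t) for j = 2, ..., n_basis.
    j = np.arange(n_basis)[:, None]
    basis = np.sqrt(2) * np.cos(j * np.pi * t[None, :])
    basis[0] = 1.0
    return basis  # shape (n_basis, len(t))


t = np.linspace(0, 1, 10_001)
phi = cosine_basis(t, 8)

# Gram matrix G[i, j] = integral over [0, 1] of phi_i * phi_j,
# approximated with the trapezoidal rule.
products = phi[:, None, :] * phi[None, :, :]
gram = np.sum((products[..., 1:] + products[..., :-1]) / 2 * np.diff(t), axis=-1)
print(np.round(gram, 6))  # approximately the 8 x 8 identity matrix
</code></pre></div></div>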

<p><div align="justify">With any fixed orthonormal basis $\{ \phi_i \}$, and assuming $f(\cdot \,|\, X=x) \in L^2(\mathbb{R})$ for every $x$, we can express the conditional probability density function as [<a href="#bibliography">3</a>]</div></p>

\[f(y \,|\, X=x) = \sum_{i=1}^\infty \beta_i(x)\, \phi_i(y).\]

<p><div align="justify">In this formulation, the dependence on $X=x$ is carried entirely by the coefficients of the linear combination, while the basis functions depend only on $y$.</div></p>

<p><div align="justify">It is worth noting that due to the orthonormality of the basis $\{ \phi_i \}$, we have that</div></p>

\[\begin{align*}
    \mathbb{E}\left( \phi_j(Y) \,|\, X=x \right) &amp;= \int_\mathbb{R} \phi_j(y) \,f(y \,|\, X=x) \,dy\\
    &amp;= \int_\mathbb{R} \phi_j(y) \sum_{i=1}^\infty \beta_i(x)\, \phi_i(y) \,dy\\
    &amp;= \sum_{i=1}^\infty \beta_i(x) \int_\mathbb{R} \phi_j(y) \, \phi_i(y) \,dy\\
    &amp;= \sum_{i=1}^\infty \beta_i(x) \,\delta_{i,j} = \beta_j(x).
\end{align*}\]

<p><div align="justify">Hence, $\beta_j(x)$ is the regression function of $\phi_j(Y)$ on $X$, so $\hat{\beta}_j(x)$ can be obtained by fitting a regression model with $X$ as predictors and $\phi_j(Y)$ as the target. Note that interchanging the summation and the integral in the derivation above is justified by <a href="https://en.wikipedia.org/wiki/Fubini%27s_theorem">Fubini's Theorem</a> (under suitable integrability conditions).</div></p>
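<p><div align="justify">The identity $\beta_j(x) = \mathbb{E}\left( \phi_j(Y) \,|\, X=x \right)$ can be checked numerically. In this sketch we fix a single $x$ and assume, purely for illustration, that $Y \,|\, X=x$ follows a Beta$(2, 5)$ distribution on $[0, 1]$, again using a hypothetical cosine basis helper on $[0, 1]$. The coefficients computed as inner products $\left\langle \phi_j, f(\cdot \,|\, X=x)\right\rangle$ should match the sample averages of $\phi_j(Y)$, which is what a regression of $\phi_j(Y)$ on $X$ estimates when $X$ is held fixed. The truncated expansion with these coefficients then approximates the density.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from scipy.stats import beta


def cosine_basis(t, n_basis):
    """Hypothetical helper: cosine basis on [0, 1], shape (len(t), n_basis)."""
    t = np.asarray(t, dtype=float).reshape(-1, 1)
    basis = np.sqrt(2) * np.cos(np.pi * np.arange(n_basis).reshape(1, -1) * t)
    basis[:, 0] = 1.0
    return basis


n_basis = 15
conditional = beta(a=2, b=5)  # assumed f(y | X = x) for one fixed x

# beta_j(x) as the inner product of phi_j and f(. | X = x), midpoint rule on [0, 1]
n_points = 20_000
y_mid = (np.arange(n_points) + 0.5) / n_points
beta_integral = cosine_basis(y_mid, n_basis).T @ conditional.pdf(y_mid) / n_points

# beta_j(x) as the conditional expectation E(phi_j(Y) | X = x), estimated by an average
y_samples = conditional.rvs(size=200_000, random_state=0)
beta_average = cosine_basis(y_samples, n_basis).mean(axis=0)

# Truncated expansion: f_hat(y | X = x) = sum_{j=1}^{n_basis} beta_j(x) phi_j(y)
y_grid = np.linspace(0, 1, 501)
f_hat = cosine_basis(y_grid, n_basis) @ beta_integral
rms_error = np.sqrt(np.mean((f_hat - conditional.pdf(y_grid)) ** 2))

print(np.max(np.abs(beta_integral - beta_average)))  # small Monte Carlo error
print(rms_error)  # small truncation error
</code></pre></div></div>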

<p><div align="justify">The <a href="https://pypi.org/project/flexcode/">FlexCode</a> algorithm [<a href="#bibliography">3</a>] adopts this approach. Given a chosen basis (the <code>basis_system</code> hyperparameter of the model), it estimates each coefficient $\beta_i(x)$ by regressing $\phi_i(Y)$ on $X$. Since the infinite sum cannot be computed in practice, it is truncated at $I$ terms, set by <code>max_basis</code> (a hyperparameter that can be tuned through cross-validation). Consequently, we obtain</div></p>

\[\hat{f}(y \,|\, X=x) = \sum_{i=1}^I \hat{\beta}_i(x) \, \phi_i(y).\]
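<p><div align="justify">To make the estimator concrete, below is a minimal sketch of the idea using only scikit-learn; it is not the <code>flexcode</code> library's implementation. Assumptions of this sketch: $y$ is rescaled to $[0, 1]$ with the training minimum and maximum, the basis is the cosine basis on $[0, 1]$, all $I$ coefficients are estimated jointly by a single multi-output <code>RandomForestRegressor</code> (instead of one regression per coefficient), and negative values of $\hat{f}$ are clipped to zero before renormalizing on the grid. The class name <code>MinimalFlexCode</code> is hypothetical.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.ensemble import RandomForestRegressor


def cosine_basis(t, n_basis):
    """Cosine basis on [0, 1]: phi_1 = 1, phi_i(t) = sqrt(2) cos(pi (i - 1) t)."""
    t = np.asarray(t, dtype=float).reshape(-1, 1)
    basis = np.sqrt(2) * np.cos(np.pi * np.arange(n_basis).reshape(1, -1) * t)
    basis[:, 0] = 1.0
    return basis


class MinimalFlexCode:
    """Hypothetical sketch: f_hat(y | x) = sum_{i=1}^{I} beta_hat_i(x) phi_i(y)."""

    def __init__(self, n_basis=31, regressor=None):
        self.n_basis = n_basis  # the truncation level I
        self.regressor = regressor if regressor is not None else RandomForestRegressor(
            max_depth=5, n_estimators=100, random_state=0
        )

    def fit(self, X, y):
        y = np.asarray(y, dtype=float)
        self.y_min_, self.y_max_ = y.min(), y.max()
        z = (y - self.y_min_) / (self.y_max_ - self.y_min_)  # rescale y to [0, 1]
        # Regression targets: column i holds phi_i(Y), so the fitted model estimates beta_i(x)
        self.regressor.fit(X, cosine_basis(z, self.n_basis))
        return self

    def predict(self, X, n_grid=400):
        z_grid = np.linspace(0, 1, n_grid)
        beta_hat = self.regressor.predict(X)  # shape (n_samples, n_basis)
        density_z = beta_hat @ cosine_basis(z_grid, self.n_basis).T  # (n_samples, n_grid)
        density_z = np.clip(density_z, 0, None)  # a density cannot be negative
        step = z_grid[1] - z_grid[0]
        density_z /= density_z.sum(axis=1, keepdims=True) * step  # integrate to 1 on [0, 1]
        # Change of variables back to the original scale of y
        y_range = self.y_max_ - self.y_min_
        return density_z / y_range, self.y_min_ + z_grid * y_range


# Usage on synthetic data where Y | X = x is roughly normal with mean 2x
rng = np.random.RandomState(0)
X_demo = rng.uniform(-1, 1, size=(1000, 1))
y_demo = 2 * X_demo[:, 0] + rng.normal(scale=0.3, size=1000)

model = MinimalFlexCode(n_basis=15).fit(X_demo, y_demo)
densities, y_grid = model.predict(X_demo[:3], n_grid=200)
print(densities.shape, y_grid.shape)  # (3, 200) (200,)
</code></pre></div></div>

<p><div align="justify">Like <code>flexcode.FlexCodeModel.predict</code>, this sketch returns the density values on a grid together with the grid, so the same kind of wrapper can be used to evaluate it.</div></p>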

<hr />

<h2 id="using-flexcode-in-python">Using FlexCode in Python</h2>

<p><div align="justify">To use FlexCode, we first need to define the regression model and its parameters, along with the hyperparameters mentioned above.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">flexcode.regression_models</span> <span class="kn">import</span> <span class="n">RandomForest</span>
<span class="kn">from</span> <span class="nn">flexcode</span> <span class="kn">import</span> <span class="n">FlexCodeModel</span>

<span class="n">flexcode_model</span> <span class="o">=</span> <span class="n">FlexCodeModel</span><span class="p">(</span>
    <span class="n">RandomForest</span><span class="p">,</span>
    <span class="n">basis_system</span><span class="o">=</span><span class="s">"cosine"</span><span class="p">,</span>
    <span class="n">max_basis</span><span class="o">=</span><span class="mi">31</span><span class="p">,</span>
    <span class="n">regression_params</span><span class="o">=</span><span class="p">{</span><span class="s">"max_depth"</span><span class="p">:</span> <span class="mi">5</span><span class="p">,</span> <span class="s">"n_estimators"</span><span class="p">:</span> <span class="mi">100</span><span class="p">},</span>
<span class="p">)</span>
<span class="n">flexcode_model</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
</code></pre></div></div>

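<p><div align="justify">The <code>predict</code> method returns the estimated values of $\hat{f}(y \,|\, X=x)$ for each sample on a grid of <code>n_grid</code> values of $y$, together with the grid itself (which we flatten with <code>reshape(-1)</code>).</div></p>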

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">cdes</span><span class="p">,</span> <span class="n">y_grid_flexcode</span> <span class="o">=</span> <span class="n">flexcode_model</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">n_grid</span><span class="o">=</span><span class="mi">400</span><span class="p">)</span>
<span class="n">y_grid_flexcode</span> <span class="o">=</span> <span class="n">y_grid_flexcode</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="k">for</span> <span class="n">c</span><span class="p">,</span> <span class="n">sample_index</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="mi">13</span><span class="p">).</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">X_test</span><span class="p">),</span> <span class="n">size</span><span class="o">=</span><span class="mi">3</span><span class="p">)):</span>
    <span class="n">x_value</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">round</span><span class="p">(</span><span class="n">X_test</span><span class="p">[</span><span class="n">sample_index</span><span class="p">],</span> <span class="mi">4</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
        <span class="n">y_grid_refined</span><span class="p">,</span>
        <span class="n">norm</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">mean_function</span><span class="p">(</span><span class="n">x_value</span><span class="p">),</span> <span class="n">scale</span><span class="o">=</span><span class="n">deviation_function</span><span class="p">(</span><span class="n">x_value</span><span class="p">)).</span><span class="n">pdf</span><span class="p">(</span>
            <span class="n">y_grid_refined</span>
        <span class="p">),</span>
        <span class="s">"--"</span><span class="p">,</span>
        <span class="n">color</span><span class="o">=</span><span class="sa">f</span><span class="s">"C</span><span class="si">{</span><span class="n">c</span><span class="si">}</span><span class="s">"</span><span class="p">,</span>
        <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">"real $f(y | x = </span><span class="si">{</span><span class="n">x_value</span><span class="si">}</span><span class="s">)$"</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
        <span class="n">y_grid_flexcode</span><span class="p">,</span>
        <span class="n">cdes</span><span class="p">[</span><span class="n">sample_index</span><span class="p">],</span>
        <span class="n">color</span><span class="o">=</span><span class="sa">f</span><span class="s">"C</span><span class="si">{</span><span class="n">c</span><span class="si">}</span><span class="s">"</span><span class="p">,</span>
        <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">"estimated $f(y | x = </span><span class="si">{</span><span class="n">x_value</span><span class="si">}</span><span class="s">)$ using flexcode"</span><span class="p">,</span>
    <span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Conditional density of y given X=x"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"y"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/conditional_density_estimation/output_40_0.png" /></center></div></p>

<p><div align="justify">To evaluate the estimator, recall that we built our metric to work with objects similar to <a href="https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KernelDensity.html"><code>sklearn.neighbors.KernelDensity</code></a>. We therefore adapt the output of <a href="https://github.com/lee-group-cmu/FlexCode"><code>flexcode.FlexCodeModel</code></a> to this format with a small wrapper class that implements the same methods (<code>score_samples</code> and <code>score</code>), turning the density values on the grid into a piecewise-constant density.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">FlexCode_return_to_DensityEstimator</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_grid</span><span class="p">,</span> <span class="n">pdf_values</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">y_grid</span> <span class="o">=</span> <span class="n">y_grid</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">pdf_values</span> <span class="o">=</span> <span class="n">pdf_values</span>
        <span class="c1"># Treat the grid values as a histogram: each grid point is the left edge of a bin,</span>
        <span class="c1"># and one extra edge (one grid step after the last point) closes the last bin.</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">density</span> <span class="o">=</span> <span class="n">rv_histogram</span><span class="p">(</span>
            <span class="p">(</span><span class="n">pdf_values</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">hstack</span><span class="p">([</span><span class="n">y_grid</span><span class="p">,</span> <span class="p">[</span><span class="n">y_grid</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">y_grid</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">y_grid</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]]]))</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">score_samples</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="c1"># Log-density of each value; -inf outside the grid, where the histogram density is 0.</span>
        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">density</span><span class="p">.</span><span class="n">pdf</span><span class="p">(</span><span class="n">X</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">score</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">score_samples</span><span class="p">(</span><span class="n">X</span><span class="p">))</span>

<span class="n">density_estimation_preds_flexcode</span> <span class="o">=</span> <span class="p">[</span>
    <span class="n">FlexCode_return_to_DensityEstimator</span><span class="p">(</span><span class="n">y_grid</span><span class="o">=</span><span class="n">y_grid_flexcode</span><span class="p">,</span> <span class="n">pdf_values</span><span class="o">=</span><span class="n">cde</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">cde</span> <span class="ow">in</span> <span class="n">cdes</span>
<span class="p">]</span>
<span class="n">squared_loss</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">density_estimation_preds_flexcode</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1000</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>-1.5436164449474372
</code></pre></div></div>
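<p><div align="justify">To see what the wrapper does, here is a small standalone check with a known density: a standard normal evaluated on an evenly spaced grid, standing in (as an illustrative assumption) for one row of <code>cdes</code>. As in <code>FlexCode_return_to_DensityEstimator</code>, each grid point is treated as the left edge of a histogram bin and one extra edge closes the last bin; <code>scipy.stats.rv_histogram</code> then normalizes the values into a piecewise-constant density. Outside the grid this density is zero, so <code>score_samples</code> returns <code>-inf</code> there (with a divide-by-zero warning from <code>np.log</code>).</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from scipy.stats import norm, rv_histogram

# Stand-in for one row of cdes: a known density evaluated on an evenly spaced grid
y_grid = np.linspace(-4, 4, 400)
pdf_values = norm.pdf(y_grid)

# Same construction as in the wrapper: grid points are left bin edges,
# plus one extra edge one grid step after the last point
bin_edges = np.hstack([y_grid, [y_grid[-1] + y_grid[-1] - y_grid[-2]]])
density = rv_histogram((pdf_values, bin_edges))

print(density.cdf(bin_edges[-1]))  # total probability mass: 1.0
print(density.pdf(0.0), norm.pdf(0.0))  # piecewise-constant value vs. true value
print(density.pdf(10.0))  # 0.0 outside the grid
</code></pre></div></div>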

<p><div align="justify">In this scenario, FlexCode obtains a lower (that is, better) value of the loss than the KDE based on nearest neighbors that we used previously.</div></p>

<hr />

<h2 id="practical-application">Practical application</h2>

<p><div align="justify">Let's apply these techniques to a real regression problem, the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_california_housing.html"><code>sklearn.datasets.fetch_california_housing</code></a> dataset, and compare the performance of the different approaches discussed.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">fetch_california_housing</span>

<span class="n">X_california</span><span class="p">,</span> <span class="n">y_california</span> <span class="o">=</span> <span class="n">fetch_california_housing</span><span class="p">(</span><span class="n">return_X_y</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="p">(</span>
    <span class="n">X_california_train</span><span class="p">,</span>
    <span class="n">X_california_test</span><span class="p">,</span>
    <span class="n">y_california_train</span><span class="p">,</span>
    <span class="n">y_california_test</span><span class="p">,</span>
<span class="p">)</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">X_california</span><span class="p">,</span> <span class="n">y_california</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.33</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"X dimension: </span><span class="si">{</span><span class="n">X_california</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>X dimension: 8
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">ckde_california</span> <span class="o">=</span> <span class="n">ConditionalNearestNeighborsKDE</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span>
    <span class="n">X_california_train</span><span class="p">,</span> <span class="n">y_california_train</span>
<span class="p">)</span>
<span class="n">ckde_california_preds</span> <span class="o">=</span> <span class="n">ckde_california</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_california_test</span><span class="p">)</span>

<span class="n">squared_loss</span><span class="p">(</span><span class="n">y_california_test</span><span class="p">,</span> <span class="n">ckde_california_preds</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">1000</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>-0.2948159711962537
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">crfkde_california</span> <span class="o">=</span> <span class="n">ConditionalNearestNeighborsKDE</span><span class="p">(</span>
    <span class="n">nn_estimator</span><span class="o">=</span><span class="n">LeafNeighbors</span><span class="p">(</span><span class="n">n_neighbors</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X_california_train</span><span class="p">,</span> <span class="n">y_california_train</span><span class="p">)</span>
<span class="n">crfkde_california_preds</span> <span class="o">=</span> <span class="n">crfkde_california</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_california_test</span><span class="p">)</span>

<span class="n">squared_loss</span><span class="p">(</span><span class="n">y_california_test</span><span class="p">,</span> <span class="n">crfkde_california_preds</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">1000</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>-0.6802084650191235
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model_california</span> <span class="o">=</span> <span class="n">FlexCodeModel</span><span class="p">(</span>
    <span class="n">RandomForest</span><span class="p">,</span> <span class="n">max_basis</span><span class="o">=</span><span class="mi">31</span><span class="p">,</span> <span class="n">regression_params</span><span class="o">=</span><span class="p">{</span><span class="s">"max_depth"</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span> <span class="s">"n_estimators"</span><span class="p">:</span> <span class="mi">100</span><span class="p">}</span>
<span class="p">)</span>
<span class="n">model_california</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_california_train</span><span class="p">,</span> <span class="n">y_california_train</span><span class="p">)</span>

<span class="n">cdes_california</span><span class="p">,</span> <span class="n">y_grid_california</span> <span class="o">=</span> <span class="n">model_california</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span>
    <span class="n">X_california_test</span><span class="p">,</span> <span class="n">n_grid</span><span class="o">=</span><span class="mi">400</span>
<span class="p">)</span>
<span class="n">y_grid_california</span> <span class="o">=</span> <span class="n">y_grid_california</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">density_estimation_preds_flexcode_california</span> <span class="o">=</span> <span class="p">[</span>
    <span class="n">FlexCode_return_to_DensityEstimator</span><span class="p">(</span><span class="n">y_grid</span><span class="o">=</span><span class="n">y_grid_california</span><span class="p">,</span> <span class="n">pdf_values</span><span class="o">=</span><span class="n">cde</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">cde</span> <span class="ow">in</span> <span class="n">cdes_california</span>
<span class="p">]</span>

<span class="n">squared_loss</span><span class="p">(</span>
    <span class="n">y_california_test</span><span class="p">,</span>
    <span class="n">density_estimation_preds_flexcode_california</span><span class="p">,</span>
    <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">1000</span><span class="p">),</span>
<span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>-1.2739272081741533
</code></pre></div></div>

<p><div align="justify">Recalling that lower values of <code>squared_loss</code> are better, the neighbor search using <code>LeafNeighbors</code> (-0.68) outperforms the conventional neighbor search in our <code>ConditionalNearestNeighborsKDE</code> method (-0.29), and <code>flexcode.FlexCodeModel</code> (-1.27) yields even better results than both methods in this example.</div></p>

<hr />

<h2 id="final-considerations">Final considerations</h2>

<p><div align="justify">Going beyond point estimates in regression problems can be challenging. However, the full conditional distribution carries much more information than a single number (for example, how uncertain a prediction is, or whether several values are plausible), which can improve decision-making. Although this is an important area of study, it has not yet become a primary focus of the community; I expect interest to grow as more practitioners realize its value.</div></p>

<p><div align="justify">The libraries designed for these problems are still being refined, with issues being fixed over time. When using them, it's therefore important to be cautious and to report any anomalous behavior you observe.</div></p>

<p><div align="justify">$\oint$ <em>I wanted to mention the tree-based neighbor method because tree training can be adapted specifically for CDE problems. In regression problems, a decision tree typically chooses its splits to optimize a metric such as <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html"><code>sklearn.metrics.mean_squared_error</code></a>. For CDE problems, however, one can optimize a CDE-specific metric directly when choosing the splits, such as the CDE <code>squared_loss</code> we implemented earlier. This is the approach proposed by the <a href="https://github.com/lee-group-cmu/RFCDE">RFCDE (Random Forests for Conditional Density Estimation)</a> method [<a href="#bibliography">6</a>].</em></div></p>

<h2 id="bibliography"><a name="bibliography">Bibliography</a></h2>

<p><div align="justify">[1] <a href="http://people.eecs.berkeley.edu/~angelopoulos/blog/posts/gentle-intro/">A Gentle Introduction to Conformal Prediction and Distribution-Free Uncertainty Quantification. Anastasios N. Angelopoulos, Stephen Bates. 2021.</a></div></p>

<p><div align="justify">[2] <a href="https://jakevdp.github.io/PythonDataScienceHandbook/05.13-kernel-density-estimation.html">Python Data Science Handbook: In-Depth Kernel Density Estimation. Jake VanderPlas. 2016.</a></div></p>

<p><div align="justify">[3] <a href="https://projecteuclid.org/journals/electronic-journal-of-statistics/volume-11/issue-2/Converting-high-dimensional-regression-to-high-dimensional-conditional-density-estimation/10.1214/17-EJS1302.full">Converting high-dimensional regression to high-dimensional conditional density estimation. Rafael Izbicki, Ann B. Lee. Electronic Journal of Statistics. 2017.</a></div></p>

<p><div align="justify">[4] <a href="https://gdmarmerola.github.io/forest-embeddings/">Supervised clustering and forest embeddings. Guilherme Duarte Marmerola. 2018.</a></div></p>

<p><div align="justify">[5] <a href="https://jmlr.csail.mit.edu/papers/volume7/meinshausen06a/meinshausen06a.pdf">Quantile Regression Forests. Nicolai Meinshausen. Journal of Machine Learning Research. 2006.</a></div></p>

<p><div align="justify">[6] <a href="https://arxiv.org/abs/1804.05753">RFCDE: Random Forests for Conditional Density Estimation. Taylor Pospisil, Ann B. Lee. 2018.</a></div></p>

<hr />

<p><div align="justify">You can find all files and environments for reproducing the experiments in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/conditional_density_estimation">repository of this post</a>.</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇺🇸&quot;, &quot;uncertainty quantification&quot;]" /><summary type="html"><![CDATA[CDE is the process of estimating the probability density function of a random variable given the values of other variables.]]></summary></entry><entry><title type="html">Hyperparameter search with threshold-dependent metrics</title><link href="https://vitaliset.github.io/threshold-dependent-opt/" rel="alternate" type="text/html" title="Hyperparameter search with threshold-dependent metrics" /><published>2023-01-06T00:00:00+00:00</published><updated>2023-01-06T00:00:00+00:00</updated><id>https://vitaliset.github.io/threshold-dependent-opt</id><content type="html" xml:base="https://vitaliset.github.io/threshold-dependent-opt/"><![CDATA[<p><div align="justify">In a binary classification problem, you probably shouldn&#39;t ever use the <code>.predict</code> method from scikit-learn (and consequently from libraries that follow <a href="https://scikit-learn.org/stable/developers/develop.html">its design pattern</a>). In scikit-learn, the implementation of <code>.predict</code>, in general, follows the logic <a href="https://github.com/scikit-learn/scikit-learn/blob/98cf537f5c538fdbc9d27b851cf03ce7611b8a48/sklearn/ensemble/_forest.py#L800-L837">implemented</a> for <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"><code>sklearn.ensemble.RandomForestClassifier</code></a>:</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
    <span class="p">...</span>
    <span class="n">proba</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
    <span class="p">...</span>
    <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">classes_</span><span class="p">.</span><span class="n">take</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">proba</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">In the case where we only have two classes (0 or 1), <code>.predict</code>, by picking the class with the highest &quot;probability&quot;, is equivalent to the rule &quot;if <code>.predict_proba(X)[:, 1] &gt; 0.5</code>, then predict <code>1</code>; otherwise, predict <code>0</code>&quot;. That is, under the hood, we are using a threshold of <code>0.5</code> without it being visible.</div></p>
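<p><div align="justify">A quick way to confirm this equivalence is to compare both rules directly. The sketch below uses a synthetic imbalanced dataset from <code>sklearn.datasets.make_classification</code> (an illustrative stand-in, not the dataset used later) and a default <code>RandomForestClassifier</code>. For labels <code>0</code> and <code>1</code>, <code>clf.classes_</code> is <code>[0, 1]</code>, so column <code>1</code> of <code>predict_proba</code> is the probability of class <code>1</code>. On a tie at exactly <code>0.5</code>, <code>np.argmax</code> returns the first class, which matches the strict inequality.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Illustrative imbalanced binary problem (about 10% of positives)
X_demo, y_demo = make_classification(n_samples=2000, weights=[0.9, 0.1], random_state=0)
X_fit, y_fit, X_new = X_demo[:1500], y_demo[:1500], X_demo[1500:]

clf = RandomForestClassifier(random_state=0).fit(X_fit, y_fit)

proba_1 = clf.predict_proba(X_new)[:, 1]  # probability of class 1
explicit_threshold_preds = np.greater(proba_1, 0.5).astype(int)  # the hidden 0.5 rule

print(np.array_equal(clf.predict(X_new), explicit_threshold_preds))  # expected: True
</code></pre></div></div>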

<p><div align="justify">So far, nothing new. However, we will show with an example how this hidden threshold can be harmful to superficial analyses that don&#39;t take it into account.</div></p>

<hr />

<h2 id="optimizing-f1-in-a-naive-way">Optimizing f1 in a naive way</h2>

<p><div align="justify">To illustrate this issue, we will use a dataset from <a href="https://imbalanced-learn.org/stable/">imbalanced-learn</a>, a library from the <a href="https://github.com/scikit-learn-contrib">scikit-learn-contrib</a> ecosystem with implementations of several techniques for dealing with imbalanced problems. Our goal is to build a model with the best possible <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">imblearn.datasets</span> <span class="kn">import</span> <span class="n">fetch_datasets</span>

<span class="n">dataset</span> <span class="o">=</span> <span class="n">fetch_datasets</span><span class="p">()[</span><span class="s">"coil_2000"</span><span class="p">]</span>
<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="p">(</span><span class="n">dataset</span><span class="p">.</span><span class="n">target</span><span class="o">==</span><span class="mi">1</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>

<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Percentage of y=1 is </span><span class="si">{</span><span class="n">np</span><span class="p">.</span><span class="nb">round</span><span class="p">(</span><span class="n">y</span><span class="p">.</span><span class="n">mean</span><span class="p">(),</span> <span class="mi">5</span><span class="p">)</span><span class="o">*</span><span class="mi">100</span><span class="si">}</span><span class="s">%."</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Number of rows is </span><span class="si">{</span><span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s">."</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Percentage of y=1 is 5.966%.
Number of rows is 9822.
</code></pre></div></div>

<p><div align="justify">I&#39;m going to divide the dataset (taking care of the stratification because we are in an imbalanced problem) into a set for training the model, a second set for choosing the threshold, and a last one for validation. We will not be dealing with the second set for now, but we will show some ways of optimizing the threshold that will need this extra set.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>

<span class="n">X_train_model</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train_model</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stratify</span><span class="o">=</span><span class="n">y</span><span class="p">)</span>
<span class="n">X_train_model</span><span class="p">,</span> <span class="n">X_train_threshold</span><span class="p">,</span> <span class="n">y_train_model</span><span class="p">,</span> <span class="n">y_train_threshold</span> <span class="o">=</span> \
<span class="n">train_test_split</span><span class="p">(</span><span class="n">X_train_model</span><span class="p">,</span> <span class="n">y_train_model</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stratify</span><span class="o">=</span><span class="n">y_train_model</span><span class="p">)</span>
</code></pre></div></div>
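<p><div align="justify">As a quick sanity check of what the two calls above produce, here is a sketch on synthetic data. It assumes that <code>make_classification</code> with roughly 6% positives is a reasonable stand-in for the post&#39;s dataset. Both calls use the default <code>test_size=0.25</code>, so the model-training, threshold and test sets end up with about 56.25% (0.75 × 0.75), 18.75% (0.75 × 0.25) and 25% of the rows. Stratification keeps the positive rate roughly equal across the three sets.</div></p>

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Synthetic imbalanced stand-in for the post's data (assumption): about 6% positives.
X, y = make_classification(n_samples=10_000, weights=[0.94], random_state=0)

# Same two-step stratified split as above (default test_size=0.25 in both calls).
X_train_model, X_test, y_train_model, y_test = train_test_split(
    X, y, random_state=0, stratify=y
)
X_train_model, X_train_threshold, y_train_model, y_train_threshold = train_test_split(
    X_train_model, y_train_model, random_state=0, stratify=y_train_model
)

# Share of rows and positive rate in each resulting set.
for name, part in [("model", y_train_model), ("threshold", y_train_threshold), ("test", y_test)]:
    print(f"{name:>9}: {len(part) / len(y):.4f} of rows, {part.mean():.3%} positives")
```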

<p><div align="justify">Suppose we want to tune the hyperparameters of a <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"><code>sklearn.ensemble.RandomForestClassifier</code></a> to get the best possible <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a> (as anticipated above).</div></p>

<p><div align="justify">I&#39;m going to create an auxiliary function to run this search for hyperparameters because we&#39;re going to do this several times (using a <a href="https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html"><code>sklearn.model_selection.GridSearchCV</code></a>, but it could be any other way to search for hyperparameters).</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">StratifiedKFold</span>

<span class="n">params</span> <span class="o">=</span> <span class="p">{</span>
    <span class="s">"max_depth"</span><span class="p">:</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="bp">None</span><span class="p">],</span>
    <span class="s">"n_estimators"</span><span class="p">:</span> <span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">100</span><span class="p">],</span>
<span class="p">}</span>

<span class="n">skfold</span> <span class="o">=</span> <span class="n">StratifiedKFold</span><span class="p">(</span><span class="n">n_splits</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
                         <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
                         <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">GridSearchCV</span>
<span class="kn">from</span> <span class="nn">sklearn.ensemble</span> <span class="kn">import</span> <span class="n">RandomForestClassifier</span>

<span class="k">def</span> <span class="nf">run_experiment</span><span class="p">(</span><span class="n">estimator</span><span class="p">,</span> <span class="n">scoring</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">cv</span><span class="p">):</span>
    <span class="n">gscv</span> <span class="o">=</span> <span class="p">(</span>
        <span class="n">GridSearchCV</span><span class="p">(</span><span class="n">estimator</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span>
                     <span class="n">param_grid</span><span class="o">=</span><span class="n">params</span><span class="p">,</span>
                     <span class="n">scoring</span><span class="o">=</span><span class="n">scoring</span><span class="p">,</span>
                     <span class="n">cv</span><span class="o">=</span><span class="n">cv</span><span class="p">)</span>
        <span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
    <span class="p">)</span>

    <span class="k">return</span> <span class="p">(</span>
        <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">gscv</span><span class="p">.</span><span class="n">cv_results_</span><span class="p">)</span>
        <span class="p">.</span><span class="n">pipe</span><span class="p">(</span><span class="k">lambda</span> <span class="n">df</span><span class="p">:</span>
              <span class="n">df</span><span class="p">[</span><span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="s">"param_"</span> <span class="o">+</span> <span class="n">x</span><span class="p">,</span>  <span class="n">params</span><span class="p">.</span><span class="n">keys</span><span class="p">()))</span> <span class="o">+</span> <span class="p">[</span><span class="s">"mean_test_score"</span><span class="p">,</span> <span class="s">"std_test_score"</span><span class="p">]])</span>
    <span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">With this auxiliary function in place, we run the search optimizing <code>scoring=&quot;f1&quot;</code>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">run_experiment</span><span class="p">(</span><span class="n">estimator</span><span class="o">=</span><span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
               <span class="n">scoring</span><span class="o">=</span><span class="s">"f1"</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">X_train_model</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y_train_model</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="n">params</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="n">skfold</span><span class="p">)</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>param_max_depth</th>
      <th>param_n_estimators</th>
      <th>mean_test_score</th>
      <th>std_test_score</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>2</td>
      <td>10</td>
      <td>0.000000</td>
      <td>0.000000</td>
    </tr>
    <tr>
      <th>1</th>
      <td>2</td>
      <td>50</td>
      <td>0.000000</td>
      <td>0.000000</td>
    </tr>
    <tr>
      <th>2</th>
      <td>2</td>
      <td>100</td>
      <td>0.000000</td>
      <td>0.000000</td>
    </tr>
    <tr>
      <th>3</th>
      <td>4</td>
      <td>10</td>
      <td>0.000000</td>
      <td>0.000000</td>
    </tr>
    <tr>
      <th>4</th>
      <td>4</td>
      <td>50</td>
      <td>0.000000</td>
      <td>0.000000</td>
    </tr>
    <tr>
      <th>5</th>
      <td>4</td>
      <td>100</td>
      <td>0.000000</td>
      <td>0.000000</td>
    </tr>
    <tr>
      <th>6</th>
      <td>10</td>
      <td>10</td>
      <td>0.059510</td>
      <td>0.039552</td>
    </tr>
    <tr>
      <th>7</th>
      <td>10</td>
      <td>50</td>
      <td>0.040333</td>
      <td>0.016119</td>
    </tr>
    <tr>
      <th>8</th>
      <td>10</td>
      <td>100</td>
      <td>0.034938</td>
      <td>0.014265</td>
    </tr>
    <tr>
      <th>9</th>
      <td>None</td>
      <td>10</td>
      <td>0.097418</td>
      <td>0.007834</td>
    </tr>
    <tr>
      <th>10</th>
      <td>None</td>
      <td>50</td>
      <td>0.105050</td>
      <td>0.022298</td>
    </tr>
    <tr>
      <th>11</th>
      <td>None</td>
      <td>100</td>
      <td>0.096360</td>
      <td>0.016211</td>
    </tr>
  </tbody>
</table>
</div>

<p><div align="justify">Some hyperparameter combinations have an <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a> of exactly 0. These are all the combinations with shallow trees (<code>max_depth</code> of 2 or 4). Weird.</div></p>

<p><div align="justify">This happens because <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a> is a threshold-dependent metric: it needs hard predictions rather than predicted probabilities. As a result, scikit-learn&#39;s scorer calls <code>.predict</code> instead of <code>.predict_proba</code>, which, as discussed earlier, is equivalent to using a threshold of <code>0.5</code>.</div></p>
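<p><div align="justify">We can check this behavior directly. The sketch below runs on synthetic imbalanced data, which is an assumption standing in for the post&#39;s dataset. It builds the scorer that the string <code>scoring=&quot;f1&quot;</code> resolves to, using <code>sklearn.metrics.get_scorer</code>. It then compares that scorer with <code>f1_score</code> applied to <code>.predict</code> and to <code>.predict_proba(X)[:, 1] &gt; 0.5</code>. All three should agree.</div></p>

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, get_scorer

# Synthetic imbalanced data (assumption): roughly 6% positives.
X, y = make_classification(n_samples=2000, weights=[0.94], random_state=0)
model = RandomForestClassifier(max_depth=2, random_state=0).fit(X, y)

# The scorer object that a string like scoring="f1" is turned into.
f1_scorer = get_scorer("f1")

from_scorer = f1_scorer(model, X, y)                  # what the search computes
from_predict = f1_score(y, model.predict(X))          # hard predictions from .predict
from_threshold_05 = f1_score(y, (model.predict_proba(X)[:, 1] > 0.5).astype(int))
print(from_scorer, from_predict, from_threshold_05)
```

With a shallow forest on data like this, all three values may well be 0, mirroring the table above.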

<p><div align="justify">Because our problem is imbalanced, a threshold of <code>0.5</code> is usually suboptimal, and that is the case here. Almost any model will concentrate most of its <code>.predict_proba</code> outputs close to 0, so a threshold closer to <code>0</code> is probably more reasonable for this problem.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">Counter</span>
<span class="n">out_of_the_box_model</span> <span class="o">=</span> <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train_model</span><span class="p">,</span> <span class="n">y_train_model</span><span class="p">)</span>

<span class="n">predict_proba</span> <span class="o">=</span> <span class="n">out_of_the_box_model</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X_train_threshold</span><span class="p">)[:,</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">predict</span> <span class="o">=</span> <span class="n">out_of_the_box_model</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_train_threshold</span><span class="p">)</span>

<span class="c1"># Just to check. ;)
</span><span class="k">assert</span> <span class="p">((</span><span class="n">predict_proba</span> <span class="o">&gt;</span> <span class="mf">0.5</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span> <span class="o">==</span> <span class="n">predict</span><span class="p">).</span><span class="nb">all</span><span class="p">()</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">ncols</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="mf">2.5</span><span class="p">))</span>

<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">hist</span><span class="p">(</span><span class="n">predict_proba</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">26</span><span class="p">))</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Histogram of .predict_proba(X)"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>

<span class="n">count_predict</span> <span class="o">=</span> <span class="n">Counter</span><span class="p">(</span><span class="n">predict</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">bar</span><span class="p">(</span><span class="n">count_predict</span><span class="p">.</span><span class="n">keys</span><span class="p">(),</span> <span class="n">count_predict</span><span class="p">.</span><span class="n">values</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">".predict(X)"</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="mf">0.4</span><span class="p">)</span>
<span class="n">count_y</span> <span class="o">=</span> <span class="n">Counter</span><span class="p">(</span><span class="n">y_train_threshold</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">bar</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">count_y</span><span class="p">.</span><span class="n">keys</span><span class="p">()))</span> <span class="o">+</span> <span class="mf">0.4</span><span class="p">,</span> <span class="n">count_y</span><span class="p">.</span><span class="n">values</span><span class="p">(),</span> <span class="n">label</span><span class="o">=</span><span class="s">"y"</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="mf">0.4</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_xticks</span><span class="p">([</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">1.2</span><span class="p">])</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_xticklabels</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">tick_params</span><span class="p">(</span><span class="n">bottom</span> <span class="o">=</span> <span class="bp">False</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_yscale</span><span class="p">(</span><span class="s">"log"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Count of 0's and 1's"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">legend</span><span class="p">(</span><span class="n">fontsize</span><span class="o">=</span><span class="mi">7</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/threshold_dependent_opt/output_13_0.png" /></center></div></p>

<p><div align="justify">Very few examples exceed the <code>0.5</code> threshold, far fewer than the actual number of class 1 samples. This suggests that a lower threshold (below <code>0.5</code>) makes more sense for this problem.</div></p>

<p><div align="justify">This is often the case in imbalanced learning. For instance, suppose 1% of your population has some disease and your model predicts that a given person has a 10% chance of having it. That person is ten times more likely to be ill than the average individual, so you should probably treat them as someone with a high likelihood of being ill.</div></p>

<hr />

<h2 id="tuning-the-threshold">Tuning the threshold</h2>

<p><div align="justify">To find the optimal threshold, we can apply the <a href="https://hastie.su.domains/ISLR2/ISLRv2_website.pdf">bootstrap</a> to a set separate from the one used for training. We then pick the threshold that optimizes the (threshold-dependent) metric of interest, in our case <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>

<span class="k">def</span> <span class="nf">optmize_threshold_metric</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">metric</span><span class="p">,</span> <span class="n">threshold_grid</span><span class="p">,</span> <span class="n">n_bootstrap</span><span class="o">=</span><span class="mi">20</span><span class="p">):</span>
    <span class="n">metric_means</span><span class="p">,</span> <span class="n">metric_stds</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">threshold_grid</span><span class="p">):</span>
        <span class="n">metrics</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_bootstrap</span><span class="p">):</span>
            <span class="n">ind_bootstrap</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">i</span><span class="p">).</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="p">),</span> <span class="n">replace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
            <span class="n">metric_val</span> <span class="o">=</span> <span class="n">metric</span><span class="p">(</span><span class="n">y</span><span class="p">[</span><span class="n">ind_bootstrap</span><span class="p">],</span>
                          <span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">ind_bootstrap</span><span class="p">])[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">t</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">))</span>
            <span class="n">metrics</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">metric_val</span><span class="p">)</span>
        <span class="n">metric_means</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">metrics</span><span class="p">))</span>
        <span class="n">metric_stds</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">std</span><span class="p">(</span><span class="n">metrics</span><span class="p">))</span>

    <span class="n">metric_means</span><span class="p">,</span> <span class="n">metric_stds</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">metric_means</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">metric_stds</span><span class="p">)</span>
    <span class="n">best_threshold</span> <span class="o">=</span> <span class="n">threshold_grid</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">metric_means</span><span class="p">)]</span>

    <span class="k">return</span> <span class="n">metric_means</span><span class="p">,</span> <span class="n">metric_stds</span><span class="p">,</span> <span class="n">best_threshold</span>
</code></pre></div></div>

<p><div align="justify">For each threshold value, we use the bootstrap to estimate the mean <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a> we would expect with that choice if we ran the experiment many times. We also record its standard deviation to get an idea of the variability of the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a>. The final threshold is the one with the best estimated mean <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a>.</div></p>
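<p><div align="justify">The bootstrap resamples (seeded by <code>i</code>) are the same for every threshold, and a random forest scores each row independently. Under those assumptions, <code>predict_proba</code> only needs to be called once. The hypothetical <code>optimize_threshold_metric_fast</code> below sketches that idea. It should return the same values as <code>optmize_threshold_metric</code> while avoiding one <code>predict_proba</code> call per threshold and resample.</div></p>

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split


def optimize_threshold_metric_fast(model, X, y, metric, threshold_grid, n_bootstrap=20):
    """Hypothetical variant of optmize_threshold_metric that scores X only once."""
    y = np.asarray(y)
    proba = model.predict_proba(X)[:, 1]  # a single predict_proba call
    # Same seeds as the original loop, so every threshold sees identical resamples.
    boot_indices = [
        np.random.RandomState(i).choice(len(y), len(y), replace=True)
        for i in range(n_bootstrap)
    ]
    metric_means, metric_stds = [], []
    for t in threshold_grid:
        values = [metric(y[ind], (proba[ind] > t).astype(int)) for ind in boot_indices]
        metric_means.append(np.mean(values))
        metric_stds.append(np.std(values))
    metric_means, metric_stds = np.array(metric_means), np.array(metric_stds)
    best_threshold = threshold_grid[np.argmax(metric_means)]
    return metric_means, metric_stds, best_threshold


# Illustrative usage on synthetic imbalanced data (assumption, not the post's dataset).
X, y = make_classification(n_samples=2000, weights=[0.94], random_state=0)
X_fit, X_thr, y_fit, y_thr = train_test_split(X, y, random_state=0, stratify=y)
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_fit, y_fit)

grid = np.linspace(0, 1, 21)
means, stds, best = optimize_threshold_metric_fast(model, X_thr, y_thr, f1_score, grid)
print(best, means.max())
```

The coarser 21-point grid is only to keep the illustration quick. Thresholds that leave no predicted positives may trigger scikit-learn&#39;s undefined-metric warning and score 0.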

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">threshold_grid</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">f1_score</span>

<span class="n">f1_means_ootb</span><span class="p">,</span> <span class="n">f1_stds_ootb</span><span class="p">,</span> <span class="n">best_threshold_ootb</span> <span class="o">=</span> \
<span class="n">optmize_threshold_metric</span><span class="p">(</span><span class="n">out_of_the_box_model</span><span class="p">,</span> <span class="n">X_train_threshold</span><span class="p">,</span> <span class="n">y_train_threshold</span><span class="p">,</span> <span class="n">f1_score</span><span class="p">,</span> <span class="n">threshold_grid</span><span class="p">)</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mf">2.5</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">threshold_grid</span><span class="p">,</span> <span class="n">f1_means_ootb</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">fill_between</span><span class="p">(</span><span class="n">threshold_grid</span><span class="p">,</span> <span class="n">f1_means_ootb</span> <span class="o">-</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">f1_stds_ootb</span><span class="p">,</span> <span class="n">f1_means_ootb</span> <span class="o">+</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">f1_stds_ootb</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">vlines</span><span class="p">(</span><span class="n">best_threshold_ootb</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">f1_means_ootb</span> <span class="o">-</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">f1_stds_ootb</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">f1_means_ootb</span> <span class="o">+</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">f1_stds_ootb</span><span class="p">),</span> <span class="s">"k"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Chosen threshold"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">11</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"Threshold"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"f1_score"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>100%|██████████| 101/101 [02:00&lt;00:00,  1.19s/it]
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/threshold_dependent_opt/output_18_1.png" /></center></div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">f1_score</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="p">(</span><span class="n">out_of_the_box_model</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X_test</span><span class="p">)[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">best_threshold_ootb</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.1878453038674033
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">f1_score</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">out_of_the_box_model</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.043478260869565216
</code></pre></div></div>

<p><div align="justify">With the optimized threshold, we get a much better <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a> on the test set (about 0.188) than with <code>.predict</code>, that is, with the <code>0.5</code> threshold (about 0.043).</div></p>

<p><div align="justify">$\oint$ <em>Here we directly choose the threshold with the best average value of the metric of interest, but there are other possibilities [<a href="#bibliography">1</a>]. For example, we could work with the &quot;confidence interval&quot; and optimize its upper or lower limit (in this case, I&#39;m plotting it only to give an order of magnitude of the variance). We could also use the threshold that maximizes <a href="https://en.wikipedia.org/wiki/Youden%27s_J_statistic">Youden&#39;s J statistic</a>. This is equivalent to taking the threshold at which the empirical CDFs of <code>.predict_proba(X[y==0])</code> and <code>.predict_proba(X[y==1])</code> are furthest apart, which is the point that defines the KS statistic.</em></div></p>
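<p><div align="justify">To make the Youden&#39;s J / KS equivalence concrete, here is a sketch. The synthetic data and the logistic regression are assumptions, chosen only to produce some scores. We get the J-maximizing threshold from <code>sklearn.metrics.roc_curve</code>. The maximum of TPR minus FPR should match the two-sample KS statistic from <code>scipy.stats.ks_2samp</code>, provided class 1 tends to receive higher scores than class 0.</div></p>

```python
import numpy as np
from scipy.stats import ks_2samp
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

# Synthetic scores (assumption): any classifier's predict_proba would do here.
X, y = make_classification(n_samples=2000, weights=[0.94], random_state=0)
proba = LogisticRegression(max_iter=1000).fit(X, y).predict_proba(X)[:, 1]

# Youden's J = TPR - FPR, evaluated at every threshold returned by roc_curve.
fpr, tpr, thresholds = roc_curve(y, proba)
j = tpr - fpr
best = np.argmax(j)
youden_threshold = thresholds[best]

# KS statistic: largest gap between the empirical CDFs of the two classes' scores.
ks = ks_2samp(proba[y == 0], proba[y == 1]).statistic
print(f"threshold={youden_threshold:.4f}, max J={j[best]:.4f}, KS={ks:.4f}")
```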

<hr />

<h2 id="back-to-hyperparameters-search">Back to hyperparameters search</h2>

<p><div align="justify">So what now? Optimizing <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a> directly in the hyperparameter search doesn&#39;t look like a good idea, since scikit-learn will use <code>.predict</code>. We will discuss three ways around this issue. None of them is necessarily better than the others; the idea is to show some options for tackling the problem.</div></p>

<h3 id="1-optimizing-a-metric-that-works-and-is-related-to-the-desired-metric">1. Optimizing a metric that works and is related to the desired metric</h3>

<p><div align="justify">The most common approach, even when you care about a threshold-dependent metric, is to run the hyperparameter search with a threshold-independent metric. Only at the end do you use something like <code>optmize_threshold_metric</code> to optimize the metric you actually care about.</div></p>

<p><div align="justify">$\oint$ <em>This sounds suboptimal, but we do this all the time in Machine Learning. Say you&#39;re interested in optimizing <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html"><code>sklearn.metrics.roc_auc_score</code></a> on a credit default classification problem. Your <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"><code>sklearn.ensemble.RandomForestClassifier</code></a> will still choose its splits with <code>criterion=&quot;gini&quot;</code> (or another impurity criterion), which is related to, but different from, <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html"><code>sklearn.metrics.roc_auc_score</code></a>. The idea here is the same. Optimizing for <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html"><code>sklearn.metrics.roc_auc_score</code></a> or <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html"><code>sklearn.metrics.average_precision_score</code></a> is not the same as optimizing for <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a>, but models that do well on the former tend to do well on the latter too.</em></div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">run_experiment</span><span class="p">(</span><span class="n">estimator</span><span class="o">=</span><span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
               <span class="n">scoring</span><span class="o">=</span><span class="s">"roc_auc"</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">X_train_model</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y_train_model</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="n">params</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="n">skfold</span><span class="p">)</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>param_max_depth</th>
      <th>param_n_estimators</th>
      <th>mean_test_score</th>
      <th>std_test_score</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>2</td>
      <td>10</td>
      <td>0.719377</td>
      <td>0.008165</td>
    </tr>
    <tr>
      <th>1</th>
      <td>2</td>
      <td>50</td>
      <td>0.746675</td>
      <td>0.007476</td>
    </tr>
    <tr>
      <th>2</th>
      <td>2</td>
      <td>100</td>
      <td>0.742196</td>
      <td>0.007105</td>
    </tr>
    <tr>
      <th>3</th>
      <td>4</td>
      <td>10</td>
      <td>0.733715</td>
      <td>0.013691</td>
    </tr>
    <tr>
      <th>4</th>
      <td>4</td>
      <td>50</td>
      <td>0.744482</td>
      <td>0.010491</td>
    </tr>
    <tr>
      <th>5</th>
      <td>4</td>
      <td>100</td>
      <td>0.747113</td>
      <td>0.007466</td>
    </tr>
    <tr>
      <th>6</th>
      <td>10</td>
      <td>10</td>
      <td>0.695511</td>
      <td>0.018646</td>
    </tr>
    <tr>
      <th>7</th>
      <td>10</td>
      <td>50</td>
      <td>0.703767</td>
      <td>0.019845</td>
    </tr>
    <tr>
      <th>8</th>
      <td>10</td>
      <td>100</td>
      <td>0.708600</td>
      <td>0.022674</td>
    </tr>
    <tr>
      <th>9</th>
      <td>None</td>
      <td>10</td>
      <td>0.652099</td>
      <td>0.031056</td>
    </tr>
    <tr>
      <th>10</th>
      <td>None</td>
      <td>50</td>
      <td>0.682542</td>
      <td>0.017131</td>
    </tr>
    <tr>
      <th>11</th>
      <td>None</td>
      <td>100</td>
      <td>0.685519</td>
      <td>0.020818</td>
    </tr>
  </tbody>
</table>
</div>

<h3 id="2-leak-the-threshold-search">2. Leak the threshold search</h3>

<p><div align="justify">But what if, for some reason, we want to explicitly optimize our metric of interest within the grid search? In that case, we need a more elaborate workaround. A reasonable proxy for how your model will perform once its threshold is tuned is to tune the threshold on the validation fold itself. Since we choose the threshold that optimizes the metric on that same validation fold, the reported value is the best one achievable over the threshold grid, so we can directly take the <code>max</code> (or the <code>min</code>, if lower is better).</div></p>

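<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">f1_score</span><span class="p">,</span> <span class="n">make_scorer</span>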

<span class="k">def</span> <span class="nf">make_threshold_independent</span><span class="p">(</span><span class="n">metric</span><span class="p">,</span> <span class="n">threshold_grid</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">),</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="bp">True</span><span class="p">):</span>
    <span class="n">opt_fun</span> <span class="o">=</span> <span class="p">{</span><span class="bp">True</span><span class="p">:</span> <span class="nb">max</span><span class="p">,</span> <span class="bp">False</span><span class="p">:</span> <span class="nb">min</span><span class="p">}</span>
    <span class="n">opt</span> <span class="o">=</span> <span class="n">opt_fun</span><span class="p">[</span><span class="n">greater_is_better</span><span class="p">]</span>
    <span class="k">def</span> <span class="nf">threshold_independent_metric</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">opt</span><span class="p">([</span><span class="n">metric</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="p">(</span><span class="n">y_pred</span> <span class="o">&gt;</span> <span class="n">t</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">),</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">threshold_grid</span><span class="p">])</span>
    <span class="k">return</span> <span class="n">threshold_independent_metric</span>

<span class="n">f1_threshold_independent_score</span> <span class="o">=</span> <span class="n">make_threshold_independent</span><span class="p">(</span><span class="n">f1_score</span><span class="p">)</span>
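<span class="c1"># needs_proba=True makes the scorer pass the positive-class probability to the metric.</span>
<span class="c1"># Recent scikit-learn versions deprecate it in favor of response_method="predict_proba".</span>
<span class="n">f1_threshold_independent_scorer</span> <span class="o">=</span> <span class="n">make_scorer</span><span class="p">(</span><span class="n">f1_threshold_independent_score</span><span class="p">,</span> <span class="n">needs_proba</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>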
</code></pre></div></div>

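<p><div align="justify">Because we passed <code>needs_proba=True</code>, the scorer receives the predicted probability of the positive class instead of the output of <code>.predict</code>, and the threshold search happens inside the metric. The resulting scorer no longer depends on a fixed threshold, so we no longer have the problem of scikit-learn using <code>.predict</code> (and its implicit 0.5 threshold).</div></p>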
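<p><div align="justify">The same leaky scorer can also be written as a plain callable with signature <code>(estimator, X, y)</code>, which scikit-learn accepts directly as <code>scoring</code>. This sidesteps <code>needs_proba</code> and its replacement across scikit-learn versions. The sketch below is hypothetical (<code>best_f1_scorer</code> is not part of the original code) and assumes binary labels <code>{0, 1}</code> and synthetic data from <code>make_classification</code> instead of the post's dataset.</div></p>

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_score


def best_f1_scorer(estimator, X, y, threshold_grid=np.linspace(0, 1, 101)):
    """Hypothetical callable scorer: best F1 over a threshold grid (method 2, leaky)."""
    # Probability of the positive class; assumes labels {0, 1}, so column 1 is class 1.
    proba = estimator.predict_proba(X)[:, 1]
    return max(
        f1_score(y, (proba > t).astype(int), zero_division=0) for t in threshold_grid
    )


# Synthetic, imbalanced data standing in for the post's dataset.
X, y = make_classification(n_samples=1000, weights=[0.8, 0.2], random_state=0)
scores = cross_val_score(
    RandomForestClassifier(max_depth=4, random_state=0),
    X, y, scoring=best_f1_scorer, cv=3,
)
print(scores.mean())
```

Like the <code>make_scorer</code> version, this still chooses the threshold on the fold it reports, so it inherits the same optimism.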

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">df_best_f1</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">estimator</span><span class="o">=</span><span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
                            <span class="n">scoring</span><span class="o">=</span><span class="n">f1_threshold_independent_scorer</span><span class="p">,</span>
                            <span class="n">X</span><span class="o">=</span><span class="n">X_train_model</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y_train_model</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="n">params</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="n">skfold</span><span class="p">)</span>

<span class="n">df_best_f1</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>param_max_depth</th>
      <th>param_n_estimators</th>
      <th>mean_test_score</th>
      <th>std_test_score</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>2</td>
      <td>10</td>
      <td>0.253281</td>
      <td>0.009199</td>
    </tr>
    <tr>
      <th>1</th>
      <td>2</td>
      <td>50</td>
      <td>0.267678</td>
      <td>0.005953</td>
    </tr>
    <tr>
      <th>2</th>
      <td>2</td>
      <td>100</td>
      <td>0.257495</td>
      <td>0.002502</td>
    </tr>
    <tr>
      <th>3</th>
      <td>4</td>
      <td>10</td>
      <td>0.241877</td>
      <td>0.017142</td>
    </tr>
    <tr>
      <th>4</th>
      <td>4</td>
      <td>50</td>
      <td>0.257753</td>
      <td>0.014293</td>
    </tr>
    <tr>
      <th>5</th>
      <td>4</td>
      <td>100</td>
      <td>0.263571</td>
      <td>0.011393</td>
    </tr>
    <tr>
      <th>6</th>
      <td>10</td>
      <td>10</td>
      <td>0.202218</td>
      <td>0.016497</td>
    </tr>
    <tr>
      <th>7</th>
      <td>10</td>
      <td>50</td>
      <td>0.225597</td>
      <td>0.032149</td>
    </tr>
    <tr>
      <th>8</th>
      <td>10</td>
      <td>100</td>
      <td>0.230246</td>
      <td>0.025504</td>
    </tr>
    <tr>
      <th>9</th>
      <td>None</td>
      <td>10</td>
      <td>0.181869</td>
      <td>0.015010</td>
    </tr>
    <tr>
      <th>10</th>
      <td>None</td>
      <td>50</td>
      <td>0.213798</td>
      <td>0.037220</td>
    </tr>
    <tr>
      <th>11</th>
      <td>None</td>
      <td>100</td>
      <td>0.209927</td>
      <td>0.034730</td>
    </tr>
  </tbody>
</table>
</div>

<p><div align="justify">On the other hand, we are leaking information and consequently overestimating our metric, since we choose the best threshold on the same cross-validation fold that we use to report the score.</div></p>
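<p><div align="justify">To see this optimism in isolation, here is a small simulation (an illustrative assumption, not part of the original experiment). It uses accuracy instead of F1 and pure-noise scores, so every threshold has a true accuracy of exactly 0.5. Picking the threshold on the same set we report gives a number above 0.5, while picking it on one set and reporting on another stays around 0.5.</div></p>

```python
import numpy as np

rng = np.random.default_rng(0)
threshold_grid = np.linspace(0, 1, 101)
n_samples, n_repeats = 100, 200


def accuracy_at(y, scores, t):
    # Accuracy of the rule "predict 1 when score > t".
    return np.mean((scores > t).astype(int) == y)


leaked, honest = [], []
for _ in range(n_repeats):
    # Pure-noise scores: labels and scores are independent.
    y_a, y_b = rng.integers(0, 2, n_samples), rng.integers(0, 2, n_samples)
    s_a, s_b = rng.random(n_samples), rng.random(n_samples)

    # Method-2 style: pick the threshold on the same set we report on.
    leaked.append(max(accuracy_at(y_b, s_b, t) for t in threshold_grid))

    # Honest: pick the threshold on set A, report on the untouched set B.
    scores_a = [accuracy_at(y_a, s_a, t) for t in threshold_grid]
    honest.append(accuracy_at(y_b, s_b, threshold_grid[np.argmax(scores_a)]))

print(f"leaked: {np.mean(leaked):.3f}, honest: {np.mean(honest):.3f}")
```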

<h3 id="3-tuning-the-threshold-during-gridsearch-on-a-chunk-of-the-training-set">3. Tuning the threshold during grid search on a chunk of the training set</h3>

<p><div align="justify">A better way to do this (in terms of correctly estimating performance during cross-validation) is to modify our estimator&#39;s training method so that it also computes the best threshold. To make clear what we are doing before looking at the details of the class we will implement, it is worth comparing methods 2 and 3.</div></p>

<p><div align="justify">In each step of our cross-validation, we have a training set and a validation set, and we use the latter to evaluate the classifier trained on the former. That is what we were doing in method 1, for instance.</div></p>

<p><div align="justify"><center><img src="/assets/img/threshold_dependent_opt/output_30_0.png" /></center></div></p>

<p><div align="justify">In method 2, we optimize the threshold on the validation set itself, reporting the best metric value among the thresholds in our grid. But since the threshold search leaks into the evaluation, we overestimate our metric, which gives an overly optimistic picture of how the model will perform.</div></p>

<p><div align="justify"><center><img src="/assets/img/threshold_dependent_opt/output_32_0.png" /></center></div></p>

<p><div align="justify">In method 3, which we are now discussing, we hold out part of the training set during the training stage: the model is fit on the rest, the held-out part is used to optimize the threshold, and this optimal threshold is then applied when evaluating on the validation set.</div></p>

<p><div align="justify"><center><img src="/assets/img/threshold_dependent_opt/output_34_0.png" /></center></div></p>
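<p><div align="justify">Before wrapping this logic in an estimator, the per-fold difference between methods 2 and 3 can be sketched as a plain loop. This sketch assumes synthetic imbalanced data from <code>make_classification</code>, a fixed <code>max_depth=4</code> forest, and no bootstrap, so it only mirrors the diagrams above and is not the post's experiment. Note that in method 3 the model sees less data (a 75/25 split of the training fold by default), which is the price paid for an honest threshold.</div></p>

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, train_test_split

# Synthetic, imbalanced data standing in for the post's dataset (an assumption).
X, y = make_classification(n_samples=2000, weights=[0.9, 0.1], random_state=0)
threshold_grid = np.linspace(0, 1, 101)


def f1_at(y_true, proba, t):
    # F1 of the rule "predict 1 when proba > t".
    return f1_score(y_true, (proba > t).astype(int), zero_division=0)


leaked, holdout = [], []
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, val_idx in cv.split(X, y):
    X_tr, y_tr, X_val, y_val = X[train_idx], y[train_idx], X[val_idx], y[val_idx]

    # Method 2: fit on the whole training fold, choose the threshold on the validation fold.
    model = RandomForestClassifier(max_depth=4, random_state=0).fit(X_tr, y_tr)
    p_val = model.predict_proba(X_val)[:, 1]
    leaked.append(max(f1_at(y_val, p_val, t) for t in threshold_grid))

    # Method 3: fit on part of the training fold, choose the threshold on the held-out rest,
    # then apply that fixed threshold to the validation fold.
    X_fit, X_thr, y_fit, y_thr = train_test_split(X_tr, y_tr, stratify=y_tr, random_state=0)
    model = RandomForestClassifier(max_depth=4, random_state=0).fit(X_fit, y_fit)
    p_thr = model.predict_proba(X_thr)[:, 1]
    best_t = threshold_grid[np.argmax([f1_at(y_thr, p_thr, t) for t in threshold_grid])]
    holdout.append(f1_at(y_val, model.predict_proba(X_val)[:, 1], best_t))

print(f"method 2: {np.mean(leaked):.3f}  method 3: {np.mean(holdout):.3f}")
```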

<p><div align="justify">A rough implementation of a class that does this logic is as follows:</div></p>

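<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">inspect</span>

<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">sklearn.ensemble</span> <span class="kn">import</span> <span class="n">RandomForestClassifier</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">f1_score</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
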
<span class="k">def</span> <span class="nf">dic_without_keys</span><span class="p">(</span><span class="n">dic</span><span class="p">,</span> <span class="n">keys</span><span class="p">):</span>
    <span class="k">return</span> <span class="p">{</span><span class="n">x</span><span class="p">:</span> <span class="n">dic</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">dic</span> <span class="k">if</span> <span class="n">x</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">keys</span><span class="p">}</span>

<span class="k">class</span> <span class="nc">ThresholdOptimizerRandomForestBinaryClassifier</span><span class="p">(</span><span class="n">RandomForestClassifier</span><span class="p">):</span>

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_bootstrap</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">metric</span><span class="o">=</span><span class="n">f1_score</span><span class="p">,</span> <span class="n">threshold_grid</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">),</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,):</span>

        <span class="n">kwargs_without_extra</span> <span class="o">=</span> <span class="n">dic_without_keys</span><span class="p">(</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">(</span><span class="s">"n_bootstrap"</span><span class="p">,</span> <span class="s">"metric"</span><span class="p">,</span> <span class="s">"threshold_grid"</span><span class="p">))</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs_without_extra</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">metric</span> <span class="o">=</span> <span class="n">metric</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">threshold_grid</span> <span class="o">=</span> <span class="n">threshold_grid</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_bootstrap</span> <span class="o">=</span> <span class="n">n_bootstrap</span>

    <span class="o">@</span><span class="nb">classmethod</span>
    <span class="k">def</span> <span class="nf">_get_param_names</span><span class="p">(</span><span class="n">cls</span><span class="p">):</span>
        <span class="n">init</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">,</span> <span class="s">"deprecated_original"</span><span class="p">,</span> <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">)</span>
        <span class="n">init_signature</span> <span class="o">=</span> <span class="n">inspect</span><span class="p">.</span><span class="n">signature</span><span class="p">(</span><span class="n">init</span><span class="p">)</span>
        <span class="n">parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">init_signature</span><span class="p">.</span><span class="n">parameters</span><span class="p">.</span><span class="n">values</span><span class="p">()</span> <span class="k">if</span> <span class="n">p</span><span class="p">.</span><span class="n">name</span> <span class="o">!=</span> <span class="s">"self"</span> <span class="ow">and</span> <span class="n">p</span><span class="p">.</span><span class="n">kind</span> <span class="o">!=</span> <span class="n">p</span><span class="p">.</span><span class="n">VAR_KEYWORD</span><span class="p">]</span>
        <span class="k">return</span> <span class="nb">sorted</span><span class="p">([</span><span class="n">p</span><span class="p">.</span><span class="n">name</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">parameters</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="s">"n_bootstrap"</span><span class="p">,</span> <span class="s">"metric"</span><span class="p">,</span> <span class="s">"threshold_grid"</span><span class="p">])</span>

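    <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>

        <span class="c1"># Hold out part of the data to tune the threshold (splitting sample_weight too, if given).</span>
        <span class="n">arrays</span> <span class="o">=</span> <span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="k">if</span> <span class="n">sample_weight</span> <span class="ow">is</span> <span class="bp">None</span> <span class="k">else</span> <span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">sample_weight</span><span class="p">)</span>
        <span class="n">splits</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="o">*</span><span class="n">arrays</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span><span class="p">,</span> <span class="n">stratify</span><span class="o">=</span><span class="n">y</span><span class="p">)</span>
        <span class="n">X_train_model</span><span class="p">,</span> <span class="n">X_train_threshold</span><span class="p">,</span> <span class="n">y_train_model</span><span class="p">,</span> <span class="n">y_train_threshold</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[:</span><span class="mi">4</span><span class="p">]</span>
        <span class="n">sample_weight_model</span> <span class="o">=</span> <span class="bp">None</span> <span class="k">if</span> <span class="n">sample_weight</span> <span class="ow">is</span> <span class="bp">None</span> <span class="k">else</span> <span class="n">splits</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span>

        <span class="nb">super</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train_model</span><span class="p">,</span> <span class="n">y_train_model</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="n">sample_weight_model</span><span class="p">)</span>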
        <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">best_threshold_</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">optmize_threshold_metric</span><span class="p">(</span><span class="n">X_train_threshold</span><span class="p">,</span> <span class="n">y_train_threshold</span><span class="p">)</span>

        <span class="k">return</span> <span class="bp">self</span>

    <span class="k">def</span> <span class="nf">optmize_threshold_metric</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
        <span class="n">metric_means</span><span class="p">,</span> <span class="n">metric_stds</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">threshold_grid</span><span class="p">:</span>
            <span class="n">metrics</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n_bootstrap</span><span class="p">):</span>
                <span class="n">ind_bootstrap</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">i</span><span class="p">).</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="p">),</span> <span class="n">replace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
                <span class="n">metric_val</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">metric</span><span class="p">(</span><span class="n">y</span><span class="p">[</span><span class="n">ind_bootstrap</span><span class="p">],</span>
                                         <span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">ind_bootstrap</span><span class="p">])[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">t</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">))</span>
                <span class="n">metrics</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">metric_val</span><span class="p">)</span>
            <span class="n">metric_means</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">metrics</span><span class="p">))</span>
            <span class="n">metric_stds</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">std</span><span class="p">(</span><span class="n">metrics</span><span class="p">))</span>

        <span class="n">metric_means</span><span class="p">,</span> <span class="n">metric_stds</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">metric_means</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">metric_stds</span><span class="p">)</span>
        <span class="n">best_threshold</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">threshold_grid</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">metric_means</span><span class="p">)]</span>

        <span class="k">return</span> <span class="n">metric_means</span><span class="p">,</span> <span class="n">metric_stds</span><span class="p">,</span> <span class="n">best_threshold</span>

    <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X</span><span class="p">)[:,</span> <span class="mi">1</span><span class="p">]</span>
        <span class="k">return</span> <span class="p">(</span><span class="n">preds</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="p">.</span><span class="n">best_threshold_</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">$\oint$ <em><a href="https://scikit-learn.org/stable/developers/develop.html#instantiation">scikit-learn doesn&#39;t like you using <code>args</code> and <code>kwargs</code> in your estimator&#39;s <code>__init__</code></a>, because <code>get_params</code> and <code>set_params</code> (which cloning and hyperparameter search rely on) discover the hyperparameters by inspecting the <code>__init__</code> signature. But as I didn&#39;t want my <code>__init__</code> to <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/shap_feature_importances_.py#L52-L71">look like this</a>, I decided to override <a href="https://github.com/scikit-learn/scikit-learn/blob/98cf537f5c538fdbc9d27b851cf03ce7611b8a48/sklearn/base.py#L122-L151"><code>_get_param_names</code></a> from <a href="https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html"><code>sklearn.base.BaseEstimator</code></a> so that it reads the parameters of the class I&#39;m inheriting from (<a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"><code>sklearn.ensemble.RandomForestClassifier</code></a>, a.k.a. <code>super()</code>) and appends the three new ones. If you want to design it properly, you should do <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/shap_feature_importances_.py#L52-L71">this</a>.</em></div></p>

<p><div align="justify">$\oint$ <em>Note that although I&#39;m inheriting from <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"><code>sklearn.ensemble.RandomForestClassifier</code></a>, I don&#39;t use any <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"><code>sklearn.ensemble.RandomForestClassifier</code></a>-specific logic here, and actually, you can do the same with any scikit-learn estimator.</em></div></p>

<p><div align="justify">Inside <code>.fit</code>, we split the data we receive with <a href="https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html"><code>sklearn.model_selection.train_test_split</code></a> and apply essentially the same optimization function we discussed earlier to the held-out part. This implementation is computationally expensive, mainly because of the bootstrap: <code>predict_proba</code> is recomputed for every threshold and every bootstrap sample. So we lower the number of bootstrap samples (<code>n_bootstrap=5</code>) to make it faster.</div></p>
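<p><div align="justify">Most of that cost likely comes from calling <code>predict_proba</code> once per threshold and per bootstrap sample. A sketch of a cheaper variant, assuming the probabilities can be computed once up front: draw the bootstrap indices once, with the same seeds as the method above, and index into the precomputed probabilities. <code>optimize_threshold_from_proba</code> is a hypothetical helper, not part of the original code. For a model whose per-row predictions do not depend on the rest of the batch, such as a random forest, it should return the same values as the method above.</div></p>

```python
import numpy as np
from sklearn.metrics import f1_score


def optimize_threshold_from_proba(proba, y, metric=f1_score,
                                  threshold_grid=np.linspace(0, 1, 101), n_bootstrap=20):
    """Hypothetical helper: bootstrap mean/std of `metric` per threshold, from precomputed scores."""
    proba, y = np.asarray(proba), np.asarray(y)
    # Same seeds as the class method, but the indices are drawn only once.
    bootstraps = [np.random.RandomState(i).choice(len(y), len(y), replace=True)
                  for i in range(n_bootstrap)]
    metric_means, metric_stds = [], []
    for t in threshold_grid:
        labels = (proba > t).astype(int)
        values = [metric(y[ind], labels[ind]) for ind in bootstraps]
        metric_means.append(np.mean(values))
        metric_stds.append(np.std(values))
    metric_means, metric_stds = np.array(metric_means), np.array(metric_stds)
    return metric_means, metric_stds, threshold_grid[np.argmax(metric_means)]
```

Inside the class, <code>optmize_threshold_metric</code> could then compute <code>self.predict_proba(X)[:, 1]</code> a single time and delegate to such a helper.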

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">time</span>

<span class="n">df_best</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">(</span>
    <span class="n">estimator</span><span class="o">=</span><span class="n">ThresholdOptimizerRandomForestBinaryClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">n_bootstrap</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
                                                             <span class="n">metric</span><span class="o">=</span><span class="n">f1_score</span><span class="p">,</span> <span class="n">threshold_grid</span><span class="o">=</span><span class="n">threshold_grid</span><span class="p">),</span>
    <span class="n">scoring</span><span class="o">=</span><span class="s">"f1"</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">X_train_model</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y_train_model</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="n">params</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="n">skfold</span><span class="p">)</span>

<span class="n">df_best</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>CPU times: total: 5min 25s
Wall time: 5min 28s
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>param_max_depth</th>
      <th>param_n_estimators</th>
      <th>mean_test_score</th>
      <th>std_test_score</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>2</td>
      <td>10</td>
      <td>0.238970</td>
      <td>0.011282</td>
    </tr>
    <tr>
      <th>1</th>
      <td>2</td>
      <td>50</td>
      <td>0.238447</td>
      <td>0.016450</td>
    </tr>
    <tr>
      <th>2</th>
      <td>2</td>
      <td>100</td>
      <td>0.243230</td>
      <td>0.022790</td>
    </tr>
    <tr>
      <th>3</th>
      <td>4</td>
      <td>10</td>
      <td>0.203598</td>
      <td>0.039442</td>
    </tr>
    <tr>
      <th>4</th>
      <td>4</td>
      <td>50</td>
      <td>0.226371</td>
      <td>0.023246</td>
    </tr>
    <tr>
      <th>5</th>
      <td>4</td>
      <td>100</td>
      <td>0.249048</td>
      <td>0.007759</td>
    </tr>
    <tr>
      <th>6</th>
      <td>10</td>
      <td>10</td>
      <td>0.200635</td>
      <td>0.034000</td>
    </tr>
    <tr>
      <th>7</th>
      <td>10</td>
      <td>50</td>
      <td>0.199724</td>
      <td>0.050758</td>
    </tr>
    <tr>
      <th>8</th>
      <td>10</td>
      <td>100</td>
      <td>0.176026</td>
      <td>0.042777</td>
    </tr>
    <tr>
      <th>9</th>
      <td>None</td>
      <td>10</td>
      <td>0.175387</td>
      <td>0.015105</td>
    </tr>
    <tr>
      <th>10</th>
      <td>None</td>
      <td>50</td>
      <td>0.158617</td>
      <td>0.015450</td>
    </tr>
    <tr>
      <th>11</th>
      <td>None</td>
      <td>100</td>
      <td>0.179195</td>
      <td>0.036804</td>
    </tr>
  </tbody>
</table>
</div>

<hr />

<h2 id="tuning-the-threshold-for-the-best-hyperparameters-combination">Tuning the threshold for the best hyperparameters combination</h2>

<p><div align="justify">With the best hyperparameter combination from method 3 chosen, we can apply the procedure discussed earlier to find the best threshold for this model.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">best_params_values</span> <span class="o">=</span> <span class="n">df_best</span><span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="s">"mean_test_score"</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="bp">False</span><span class="p">).</span><span class="n">iloc</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="s">"param_"</span> <span class="o">+</span> <span class="n">x</span><span class="p">,</span>  <span class="n">params</span><span class="p">.</span><span class="n">keys</span><span class="p">()))].</span><span class="n">values</span>
<span class="n">best_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">params</span><span class="p">.</span><span class="n">keys</span><span class="p">(),</span> <span class="n">best_params_values</span><span class="p">))</span>
<span class="n">best_params</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>{'max_depth': 4, 'n_estimators': 100}
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">best_model</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="p">.</span><span class="n">set_params</span><span class="p">(</span><span class="o">**</span><span class="n">best_params</span><span class="p">)</span>
    <span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train_model</span><span class="p">,</span> <span class="n">y_train_model</span><span class="p">)</span>
<span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">f1_means_best</span><span class="p">,</span> <span class="n">f1_stds_best</span><span class="p">,</span> <span class="n">best_threshold_best</span> <span class="o">=</span> \
<span class="n">optmize_threshold_metric</span><span class="p">(</span><span class="n">best_model</span><span class="p">,</span> <span class="n">X_train_threshold</span><span class="p">,</span> <span class="n">y_train_threshold</span><span class="p">,</span> <span class="n">f1_score</span><span class="p">,</span> <span class="n">threshold_grid</span><span class="p">)</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mf">2.5</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">threshold_grid</span><span class="p">,</span> <span class="n">f1_means_best</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">fill_between</span><span class="p">(</span><span class="n">threshold_grid</span><span class="p">,</span> <span class="n">f1_means_best</span> <span class="o">-</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">f1_stds_best</span><span class="p">,</span> <span class="n">f1_means_best</span> <span class="o">+</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">f1_stds_best</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">vlines</span><span class="p">(</span><span class="n">best_threshold_best</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">f1_means_best</span> <span class="o">-</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">f1_stds_best</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">f1_means_best</span> <span class="o">+</span> <span class="mf">1.96</span> <span class="o">*</span> <span class="n">f1_stds_best</span><span class="p">),</span> <span class="s">"k"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Chosen threshold"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">11</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"Threshold"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"f1_score"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>100%|██████████| 101/101 [01:13&lt;00:00,  1.37it/s]
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/threshold_dependent_opt/output_43_1.png" /></center></div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">f1_score</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="p">(</span><span class="n">best_model</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X_test</span><span class="p">)[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">best_threshold_best</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.24038461538461534
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">f1_score</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">best_model</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.0
</code></pre></div></div>

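<p><div align="justify">Notice that, with the tuned threshold, we got a much better <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"><code>sklearn.metrics.f1_score</code></a> on the test set than the initial search was telling us we would get, while the plain <code>.predict</code> (which thresholds at <code>0.5</code>) scores <code>0.0</code>!</div></p>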

<hr />

<h2 id="tldr">tl;dr</h2>

<p><div align="justify">When optimizing hyperparameters, threshold-dependent metrics make <a href="https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html"><code>sklearn.model_selection.GridSearchCV</code></a>-like search methods use the estimator&#39;s <code>.predict</code> method instead of <code>.predict_proba</code>. This can be harmful because <code>.predict</code> implicitly uses a threshold of <code>0.5</code>, which might not be the best one, especially in imbalanced learning scenarios.</div></p>

<p><div align="justify">Prefer threshold-independent metrics whenever possible. If you need a threshold-dependent metric, you can make it threshold-independent by taking its optimal value (<code>max</code> or <code>min</code>, depending on whether <code>greater_is_better</code> is <code>True</code> or <code>False</code>) over a grid of candidate thresholds. Since this amounts to tuning the threshold on the validation set itself, it can slightly overestimate your results.</div></p>
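<p><div align="justify">A minimal NumPy sketch of the "best value over a threshold grid" idea for F1. The helpers <code>f1_at_threshold</code> and <code>best_f1_over_grid</code> are hypothetical names used only for illustration (they are not part of scikit-learn), and F1 is written out by hand instead of calling <code>sklearn.metrics.f1_score</code>. The toy scores are made up to mimic an imbalanced problem where every positive stays below <code>0.5</code>:</div></p>

```python
import numpy as np


def f1_at_threshold(y_true, y_score, threshold):
    """F1 of the positive class when predicting 1 for scores strictly above `threshold`."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_score) > threshold
    tp = np.sum(y_pred & y_true)
    fp = np.sum(y_pred & ~y_true)
    fn = np.sum(~y_pred & y_true)
    denom = 2 * tp + fp + fn
    # Undefined case (no positives at all, neither true nor predicted): return 0.
    return 0.0 if denom == 0 else float(2 * tp / denom)


def best_f1_over_grid(y_true, y_score, threshold_grid):
    """Threshold-independent version of F1: its best value over a grid of thresholds."""
    scores = [f1_at_threshold(y_true, y_score, t) for t in threshold_grid]
    best = int(np.argmax(scores))
    return scores[best], float(threshold_grid[best])


# Toy imbalanced example: positives get higher scores, but all of them stay below 0.5.
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1])
y_score = np.array([0.05, 0.10, 0.12, 0.20, 0.25, 0.30, 0.35, 0.40])

print(f1_at_threshold(y_true, y_score, 0.5))  # 0.0, the default cut-off predicts no positive
print(best_f1_over_grid(y_true, y_score, np.linspace(0, 1, 101)))
```

<p><div align="justify">To plug something like this into <code>GridSearchCV</code>, its <code>scoring</code> parameter accepts a callable with signature <code>(estimator, X, y)</code>, for example one returning <code>best_f1_over_grid(y, estimator.predict_proba(X)[:, 1], grid)[0]</code>.</div></p>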

<p><div align="justify">A more honest approach is to explicitly optimize the threshold on a held-out part of the training set within each cross-validation fold. This mimics reality better, but it is more time-consuming, since a robust threshold optimization (for instance, using bootstrap to better estimate the performance) takes time.</div></p>
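<p><div align="justify">A minimal sketch of this per-fold procedure. It uses NumPy only and works with any scikit-learn-style classifier returned by a hypothetical <code>make_model</code> factory (for instance <code>lambda: RandomForestClassifier(random_state=0)</code>). The folds are plain random splits. With imbalanced data, stratified splits such as <code>StratifiedKFold</code> are probably a better choice. The bootstrap refinement mentioned above is left out for brevity:</div></p>

```python
import numpy as np


def f1_at_threshold(y_true, y_score, threshold):
    """F1 of the positive class when predicting 1 for scores strictly above `threshold`."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_score) > threshold
    tp = np.sum(y_pred & y_true)
    denom = 2 * tp + np.sum(y_pred & ~y_true) + np.sum(~y_pred & y_true)
    return 0.0 if denom == 0 else float(2 * tp / denom)


def cv_with_tuned_threshold(make_model, X, y, threshold_grid,
                            n_splits=5, tune_fraction=0.3, seed=0):
    """Per fold: fit on one part of the training fold, choose the threshold on the
    other part, and only then evaluate on the untouched validation fold."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(y)), n_splits)  # plain, non-stratified folds
    fold_f1, fold_thresholds = [], []
    for k, val_idx in enumerate(folds):
        train_idx = np.concatenate([f for i, f in enumerate(folds) if i != k])
        n_tune = int(tune_fraction * len(train_idx))
        tune_idx, fit_idx = train_idx[:n_tune], train_idx[n_tune:]

        model = make_model().fit(X[fit_idx], y[fit_idx])
        tune_proba = model.predict_proba(X[tune_idx])[:, 1]
        # The threshold is chosen without looking at the validation fold.
        tune_f1 = [f1_at_threshold(y[tune_idx], tune_proba, t) for t in threshold_grid]
        best_t = float(threshold_grid[int(np.argmax(tune_f1))])

        val_proba = model.predict_proba(X[val_idx])[:, 1]
        fold_f1.append(f1_at_threshold(y[val_idx], val_proba, best_t))
        fold_thresholds.append(best_t)
    return np.array(fold_f1), np.array(fold_thresholds)
```

<p><div align="justify">The spread of <code>fold_thresholds</code> across folds also gives a rough idea of how stable the chosen threshold is.</div></p>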

<p><div align="justify">$\oint$ <em>This is the current state of this topic as of version 1.2.0 of scikit-learn. In a future release, there will be a <code>sklearn.model_selection.CutoffClassifier</code> (from <a href="https://github.com/scikit-learn/scikit-learn/pull/16525">PR #16525</a>) that will behave very similarly to my <code>ThresholdOptimizerRandomForestBinaryClassifier</code>. One significant difference is that it will receive the estimator as an initialization parameter instead of inheriting from it.</em></div></p>

<h2 id="bibliography"><a name="bibliography">Bibliography</a></h2>

<p><div align="justify">[1] <a href="https://machinelearningmastery.com/threshold-moving-for-imbalanced-classification/">A Gentle Introduction to Threshold-Moving for Imbalanced Classification by Jason Brownlee.</a></div></p>

<hr />

<p><div align="justify">You can find all files and environments for reproducing the experiments in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/threshold_dependent_opt">repository of this post</a>. In addition, I recorded a <a href="https://youtu.be/I6WDGNC_YJQ">video version</a> of this post in Portuguese.</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇺🇸&quot;, &quot;imbalanced learning&quot;]" /><summary type="html"><![CDATA[It can be dangerous to do hyperparameter tuning with threshold-dependent metrics directly. Here we discuss why and how to work around it.]]></summary></entry><entry><title type="html">Meta K-Means: an ensemble of K-Means</title><link href="https://vitaliset.github.io/metakmeans/" rel="alternate" type="text/html" title="Meta K-Means: an ensemble of K-Means" /><published>2022-10-23T00:00:00+00:00</published><updated>2022-10-23T00:00:00+00:00</updated><id>https://vitaliset.github.io/metakmeans</id><content type="html" xml:base="https://vitaliset.github.io/metakmeans/"><![CDATA[<p><div align="justify">After hearing in passing about ensembles of clustering algorithms [<a href="#bibliography">1</a>], I wondered: what would be a clever way to aggregate the individual decisions of each clustering into a final result? The answer is not immediate, mainly because each clustering may name its clusters differently even when the clusterings agree on how to split the data.</div></p>

<p><div align="justify">For example, given a set of eight samples, the segmentations <code>[0, 0, 1, 0, 2, 2, 2, 1]</code> and <code>[1, 1, 0, 1, 3, 3, 3, 0]</code> are identical up to a permutation of names. It is enough to swap the names of clusters 0 and 1 in one of the lists and to rename cluster 3 as 2 in the second list (or cluster 2 as 3 in the first list). It is important to be clear that these clusterings do agree, since the names carry no meaning at all: we are not in a classification problem.</div></p>

<p><div align="justify"><center><img src="/assets/img/metakmeans/output_3_0.png" /></center></div></p>
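<p><div align="justify">One simple way to check that two labelings agree up to a permutation is to rename the clusters in order of first appearance and compare the results. This is a small sketch of our own, not part of the original experiment:</div></p>

```python
def relabel_by_first_appearance(labels):
    """Rename clusters as 0, 1, 2, ... in the order in which they first appear."""
    mapping = {}
    for label in labels:
        if label not in mapping:
            mapping[label] = len(mapping)
    return [mapping[label] for label in labels]


a = [0, 0, 1, 0, 2, 2, 2, 1]
b = [1, 1, 0, 1, 3, 3, 3, 0]
print(relabel_by_first_appearance(a))  # [0, 0, 1, 0, 2, 2, 2, 1]
print(relabel_by_first_appearance(b))  # [0, 0, 1, 0, 2, 2, 2, 1]
print(relabel_by_first_appearance(a) == relabel_by_first_appearance(b))  # True
```

<p><div align="justify">This only detects exact agreement. When two clusterings disagree on some samples, we need a graded similarity measure.</div></p>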

<p><div align="justify">This motivates &quot;clustering evaluation&quot; metrics such as <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.rand_score.html"><code>sklearn.metrics.rand_score</code></a>, which answers the question: how similar are two clusterings? A value close to 1 means that the groupings largely agree (up to possible renaming of the clusters).</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">rand_score</span>

<span class="n">rand_score</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>1.0
</code></pre></div></div>

<p><div align="justify">$\oint$ <em>The idea behind the <a href="https://en.wikipedia.org/wiki/Rand_index">(unadjusted) Rand index</a> is quite intuitive. To explain it, let us consider a specific example. Imagine that we have a dataset <code>[a, b, c, d]</code> and two possible clusterings: <code>A = [1, 1, 0, 0]</code> and <code>B = [1, 1, 1, 2]</code>.</em></div></p>

<ol>
  <li>
    <p><div align="justify"><em>First, we list all possible pairs of elements in our set. In our example, these are <code>(a, b)</code>, <code>(a, c)</code>, <code>(a, d)</code>, <code>(b, c)</code>, <code>(b, d)</code> and <code>(c, d)</code>.</em></div></p>
  </li>
  <li>
    <p><div align="justify"><em>Next, we count how many of these pairs agree between clusterings <code>A</code> and <code>B</code>. A pair agrees when its two elements are in the same cluster in both <code>A</code> and <code>B</code>, or in different clusters in both. In our case, the pair <code>(a, b)</code> agrees because its elements share a cluster in both <code>A</code> and <code>B</code>. The pairs <code>(a, d)</code> and <code>(b, d)</code> also agree, because they are placed in different clusters in both clusterings. The remaining pairs, <code>(a, c)</code>, <code>(b, c)</code> and <code>(c, d)</code>, disagree.</em></div></p>
  </li>
  <li>
    <p><div align="justify"><em>Finally, dividing the number of agreeing pairs by the total number of pairs gives the unadjusted Rand index, our similarity measure between groupings. In our case, <code>3/6=0.5</code>.</em></div></p>
  </li>
</ol>
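<p><div align="justify">The three steps translate almost literally into plain Python. This quadratic-time sketch is only for illustration. In practice you would use <code>sklearn.metrics.rand_score</code>, as below:</div></p>

```python
from itertools import combinations


def unadjusted_rand_index(labels_a, labels_b):
    """Fraction of element pairs on which two clusterings agree."""
    pairs = list(combinations(range(len(labels_a)), 2))  # step 1: all pairs
    agreements = sum(  # step 2: same/same or different/different
        (labels_a[i] == labels_a[j]) == (labels_b[i] == labels_b[j])
        for i, j in pairs
    )
    return agreements / len(pairs)  # step 3: ratio to the number of pairs


print(unadjusted_rand_index([1, 1, 0, 0], [1, 1, 1, 2]))  # 0.5
print(unadjusted_rand_index([0, 0, 1, 0, 2, 2, 2, 1], [1, 1, 0, 1, 3, 3, 3, 0]))  # 1.0
```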

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">rand_score</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.5
</code></pre></div></div>

<p><div align="justify">These permutations make the problem much more challenging than in a supervised ensemble. There is an extensive literature [<a href="#bibliography">1</a>] that tries to address it, since we would like to use ensemble ideas here as well.</div></p>

<p><div align="justify">In conversations with <a href="https://www.linkedin.com/in/atmg92/">Alessandro</a>, we tried to tackle a narrower version of this problem: the specific case of an ensemble of <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a>. Although simpler, this case could still bring practical gains, given how popular the method is. The hypothesis was that the centroids can be used to find the agreements between the different individual estimators. This led to the idea of clustering the centroids of the individual <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a> to rename their clusters consistently across all individual estimators.</div></p>

<p><div align="justify">An example helps illustrate the idea. Suppose we have two <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a> with <code>n_clusters=3</code>. Then we have three centroids $K_1, K_2, K_3$ from the first <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a> and the centroids $C_1, C_2, C_3$ from the second <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a>. Suppose that clustering these six centroids (with the same number of clusters <code>n_clusters</code>) gives the metaclusters $G_1 = \{ K_1, C_1 \}$, $G_2 = \{ K_2, K_3, C_3\}$ and $G_3 = \{ C_2\}$. That gives us a mapping to use when aggregating the results of the different individual <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a>.</div></p>

<p><div align="justify">Take a sample that falls in the cluster of centroid $K_1$ in the first clustering and in the cluster of $C_3$ in the second. It is associated with cluster $G_1$ with weight $1/2=0.5$, because one of the two base K-Means assigned it to that group. It is associated with cluster $G_2$ with weight $1/2=0.5$ for the same reason. It is associated with cluster $G_3$ with weight $0/2=0$, because neither base K-Means assigned it to that group. A sample that falls in $K_3$ and $C_3$ in the individual clusterings, on the other hand, is associated with group $G_2$ with weight $2/2=1$ and with the other $G_i$ with weight $0$. Other cases work the same way. In this form, we are back to the voting idea of a classic classification ensemble. It gives a membership index of each sample in each cluster, as in a <a href="https://en.wikipedia.org/wiki/Fuzzy_clustering">soft clustering</a> algorithm.</div></p>
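<p><div align="justify">The toy example above can be written as a small NumPy computation. Indices are 0-based here, so metacluster 0 is $G_1$, 1 is $G_2$ and 2 is $G_3$. The dictionaries <code>meta_maps</code> encode the hypothetical assignment of centroids to metaclusters described in the text, and <code>membership</code> is an illustrative helper name:</div></p>

```python
import numpy as np

# First K-Means: K_1 -> G_1, K_2 -> G_2, K_3 -> G_2.
# Second K-Means: C_1 -> G_1, C_2 -> G_3, C_3 -> G_2.
meta_maps = [{0: 0, 1: 1, 2: 1}, {0: 0, 1: 2, 2: 1}]
n_meta = 3

# Each row: the base cluster assigned to one sample by each base K-Means.
base_labels = np.array([
    [0, 2],  # falls in K_1 and C_3
    [2, 2],  # falls in K_3 and C_3
])


def membership(base_labels, meta_maps, n_meta):
    """Average one-hot votes over base models after translating labels to metaclusters."""
    votes = np.array([
        [meta_maps[m][label] for m, label in enumerate(row)] for row in base_labels
    ])  # shape (n_samples, n_models)
    one_hot = np.eye(n_meta)[votes]  # shape (n_samples, n_models, n_meta)
    return one_hot.mean(axis=1)


# Expected rows: [0.5, 0.5, 0] for the first sample and [0, 1, 0] for the second.
print(membership(base_labels, meta_maps, n_meta))
```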

<hr />

<h2 id="testando-a-ideia-no-dataset-de-dígitos">Testing the idea on the digits dataset</h2>

<p><div align="justify">To experiment with this model, let us play with the dataset of low-resolution images of handwritten digits, which we can load with the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html"><code>sklearn.datasets.load_digits</code></a> function. Below, <code>n_class=9</code> keeps only the digits 0 to 8.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">load_digits</span>

<span class="n">digits</span> <span class="o">=</span> <span class="n">load_digits</span><span class="p">(</span><span class="n">n_class</span><span class="o">=</span><span class="mi">9</span><span class="p">)</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">digits</span><span class="p">.</span><span class="n">data</span>
<span class="n">X</span><span class="p">.</span><span class="n">shape</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(1617, 64)
</code></pre></div></div>

<p><div align="justify">We want the individual clusterings to differ, so that they do not fully agree even up to permutation. There are two ways to introduce this variance. One is to change the training strategy of <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a>, for example by reducing the number of initializations it runs to find the best partition in terms of inertia. The other is to bootstrap our training set, inspired by how <a href="https://en.wikipedia.org/wiki/Bootstrap_aggregating">bagging</a> works in the supervised case. In this experiment, we go with the second option.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.cluster</span> <span class="kn">import</span> <span class="n">KMeans</span>

<span class="n">n_estimators</span> <span class="o">=</span> <span class="mi">250</span>
<span class="n">n_clusters</span> <span class="o">=</span> <span class="mi">9</span>

<span class="n">km_list</span> <span class="o">=</span> \
<span class="p">[</span><span class="n">KMeans</span><span class="p">(</span><span class="n">n_clusters</span><span class="o">=</span><span class="n">n_clusters</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">i</span><span class="p">)</span>
 <span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">i</span><span class="p">).</span><span class="n">choice</span><span class="p">(</span><span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])])</span> 
 <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_estimators</span><span class="p">))]</span>
</code></pre></div></div>
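<p><div align="justify">Why does this bootstrap make the base models disagree? Each call to <code>np.random.RandomState(i).choice(X.shape[0], X.shape[0])</code> draws <code>X.shape[0]</code> indices with replacement. As a result, each base K-Means sees only about 63.2% of the distinct samples on average (the limit $1 - 1/e$), some of them repeated. Here is a quick NumPy check of this well-known property, separate from the original experiment:</div></p>

```python
import numpy as np

n = 1617  # number of samples in X (see X.shape above)
# Same sampling as in km_list: n indices drawn with replacement.
fractions = [len(np.unique(np.random.RandomState(seed).choice(n, n))) / n for seed in range(5)]
print(np.round(fractions, 3))
print(round(1 - 1 / np.e, 3))  # theoretical limit: 0.632
```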

<p><div align="justify">After training the different <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a>, we need to train the &quot;Meta K-Means&quot;, which uses their centroids as training data.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">cluster_centers</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">vstack</span><span class="p">([</span><span class="n">km</span><span class="p">.</span><span class="n">cluster_centers_</span> <span class="k">for</span> <span class="n">km</span> <span class="ow">in</span> <span class="n">km_list</span><span class="p">])</span>

<span class="n">meta_kmeans</span> <span class="o">=</span> <span class="n">KMeans</span><span class="p">(</span><span class="n">n_clusters</span><span class="o">=</span><span class="n">n_clusters</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">cluster_centers</span><span class="p">)</span>
</code></pre></div></div>

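<p><div align="justify">This way, we can build the mappings that group the centroids. They translate the labels of the individual clusterings so that the clusterings agree according to the grouping criterion of the "Meta K-Means". Since <code>np.vstack</code> stacks the <code>n_clusters</code> centroids of each base model in order, centroid <code>j</code> of the <code>i</code>-th model sits at position <code>n_clusters*i+j</code> of <code>meta_kmeans.labels_</code>.</div></p>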

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">meta_clusters_map</span> <span class="o">=</span> \
<span class="p">[{</span><span class="n">j</span><span class="p">:</span> <span class="n">meta_kmeans</span><span class="p">.</span><span class="n">labels_</span><span class="p">[</span><span class="n">n_clusters</span><span class="o">*</span><span class="n">i</span><span class="o">+</span><span class="n">j</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_clusters</span><span class="p">)}</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_estimators</span><span class="p">)]</span>
</code></pre></div></div>

<p><div align="justify">To combine the individual clusterings, we aggregate their translated labels in some way, such as the mean, which amounts to a simple vote. This gives a membership index of each sample in each cluster.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">LabelBinarizer</span>

<span class="n">lb</span> <span class="o">=</span> <span class="n">LabelBinarizer</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_clusters</span><span class="p">)))</span>

<span class="n">aggregated_predicts</span> <span class="o">=</span> \
<span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">lb</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="n">map_dic</span><span class="p">.</span><span class="n">get</span><span class="p">,</span> <span class="n">km</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X</span><span class="p">)))))</span>
          <span class="k">for</span> <span class="n">km</span><span class="p">,</span> <span class="n">map_dic</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">km_list</span><span class="p">,</span> <span class="n">meta_clusters_map</span><span class="p">)]).</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

<span class="n">aggregated_predicts</span><span class="p">.</span><span class="n">shape</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(1617, 9)
</code></pre></div></div>
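<p><div align="justify">Because the centroids are stacked model by model, <code>meta_kmeans.labels_</code> can also be reshaped into an <code>(n_estimators, n_clusters)</code> lookup table. The dictionaries plus <code>LabelBinarizer</code> can then be replaced by array indexing. This should match the version above for <code>n_clusters &gt; 2</code>; with two classes, <code>LabelBinarizer</code> returns a single column. The sketch below uses random stand-ins (<code>meta_labels</code> and <code>base_preds</code> are hypothetical names) instead of the fitted models:</div></p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_estimators, n_clusters, n_samples = 4, 3, 6

# Stand-ins for meta_kmeans.labels_ and the stacked km.predict(X) outputs.
meta_labels = rng.integers(n_clusters, size=n_estimators * n_clusters)
base_preds = rng.integers(n_clusters, size=(n_estimators, n_samples))

# lookup[i, j] == meta_labels[n_clusters * i + j]: metacluster of centroid j of model i.
lookup = meta_labels.reshape(n_estimators, n_clusters)

# Translate every base prediction into its metacluster label: (n_estimators, n_samples).
meta_preds = lookup[np.arange(n_estimators)[:, None], base_preds]

# Fraction of base models voting for each metacluster: (n_samples, n_clusters).
membership = (meta_preds[:, :, None] == np.arange(n_clusters)).mean(axis=0)
print(membership.shape)  # (6, 3)
```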

<p><div align="justify">To check whether what we found makes sense, let us try to interpret the metacentroids. These are the centroids we obtain when running <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a> on the centroids of the base <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a>. Since we are working with the digits dataset, we can plot the metacentroid of each final cluster as an 8x8 image.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">ncols</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">nrows</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>

<span class="n">plt</span><span class="p">.</span><span class="n">gray</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">product</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">)):</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">].</span><span class="n">matshow</span><span class="p">(</span><span class="n">meta_kmeans</span><span class="p">.</span><span class="n">cluster_centers_</span><span class="p">[</span><span class="mi">3</span><span class="o">*</span><span class="n">i</span><span class="o">+</span><span class="n">j</span><span class="p">].</span><span class="n">reshape</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">))</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">].</span><span class="n">set_xticks</span><span class="p">([])</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">].</span><span class="n">set_yticks</span><span class="p">([])</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">"Cluster </span><span class="si">{</span><span class="mi">3</span><span class="o">*</span><span class="n">i</span><span class="o">+</span><span class="n">j</span><span class="si">}</span><span class="s"> centroid"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/metakmeans/output_20_0.png" /></center></div></p>

<p><div align="justify">Visual inspection lets us name the clusters after the shapes of the digits, giving the following dictionary (from cluster index to digit):</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dict_cluster</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">:</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">:</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">4</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">5</span><span class="p">:</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">:</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">7</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">:</span> <span class="mi">7</span><span class="p">}</span>
</code></pre></div></div>

<p><div align="justify">We want to see the final clusters and the regions of the space where points with uncertain cluster assignments lie. To do so, we apply <a href="https://scikit-learn.org/stable/modules/generated/sklearn.manifold.MDS.html"><code>sklearn.manifold.MDS</code></a> followed by <a href="https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html"><code>sklearn.manifold.TSNE</code></a> to reduce the dimensionality of our data.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.manifold</span> <span class="kn">import</span> <span class="n">MDS</span><span class="p">,</span> <span class="n">TSNE</span>

<span class="n">X_emb</span> <span class="o">=</span> \
<span class="p">(</span><span class="n">TSNE</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">).</span><span class="n">fit_transform</span><span class="p">(</span><span class="n">MDS</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">).</span><span class="n">fit_transform</span><span class="p">(</span><span class="n">X</span><span class="p">)))</span>
</code></pre></div></div>

<p><div align="justify">It is nice to see that our clusters are consistent with the original digit labels, but the most important plot here is the last one. It shows that some samples are indeed harder to assign confidently to a single cluster. Examples include images of the digit 8, which are easily confused with other digits, and samples that seem to lie "on the border" between two groupings.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">ncols</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>

<span class="n">im0</span> <span class="o">=</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_emb</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">X_emb</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">s</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">digits</span><span class="p">.</span><span class="n">target</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">"Set1"</span><span class="p">)</span>
<span class="n">cbar0</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">(</span><span class="n">im0</span><span class="p">,</span> <span class="n">ax</span><span class="o">=</span><span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">ticks</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">7.5</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span>
<span class="n">cbar0</span><span class="p">.</span><span class="n">ax</span><span class="p">.</span><span class="n">set_yticklabels</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Real number class"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>

<span class="n">im1</span> <span class="o">=</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_emb</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">X_emb</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">s</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
                    <span class="n">c</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="n">dict_cluster</span><span class="p">.</span><span class="n">get</span><span class="p">,</span> <span class="n">aggregated_predicts</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))),</span>
                    <span class="n">cmap</span><span class="o">=</span><span class="s">"Set1"</span><span class="p">)</span>
<span class="n">cbar1</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">(</span><span class="n">im1</span><span class="p">,</span> <span class="n">ax</span><span class="o">=</span><span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">ticks</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">7.5</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span>
<span class="n">cbar1</span><span class="p">.</span><span class="n">ax</span><span class="p">.</span><span class="n">set_yticklabels</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Cluster class"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>

<span class="n">cmap2</span> <span class="o">=</span> <span class="n">colors</span><span class="p">.</span><span class="n">ListedColormap</span><span class="p">([</span><span class="s">"#e41a1c"</span><span class="p">,</span> <span class="s">"#4daf4a"</span><span class="p">])</span>
<span class="n">im2</span> <span class="o">=</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">].</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_emb</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">X_emb</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">s</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
                    <span class="n">c</span><span class="o">=</span><span class="p">(</span><span class="n">aggregated_predicts</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">==</span><span class="mi">1</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">),</span> <span class="n">cmap</span><span class="o">=</span><span class="n">cmap2</span><span class="p">)</span>
<span class="n">im2</span><span class="p">.</span><span class="n">set_clim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">cbar2</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">(</span><span class="n">im2</span><span class="p">,</span> <span class="n">ax</span><span class="o">=</span><span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">ticks</span><span class="o">=</span><span class="p">[</span><span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.75</span><span class="p">])</span>
<span class="n">cbar2</span><span class="p">.</span><span class="n">ax</span><span class="p">.</span><span class="n">set_yticklabels</span><span class="p">([</span><span class="s">"Some uncertainty"</span><span class="p">,</span> <span class="s">"No uncertainty"</span><span class="p">],</span>
                         <span class="n">rotation</span><span class="o">=</span><span class="mi">270</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s">"center"</span><span class="p">,</span> <span class="n">rotation_mode</span><span class="o">=</span><span class="s">"anchor"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">9</span><span class="p">)</span>
<span class="n">cbar2</span><span class="p">.</span><span class="n">ax</span><span class="p">.</span><span class="n">tick_params</span><span class="p">(</span><span class="n">pad</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Certainty about the assigned cluster"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>

<span class="n">cmap3</span> <span class="o">=</span> <span class="n">colors</span><span class="p">.</span><span class="n">LinearSegmentedColormap</span><span class="p">.</span><span class="n">from_list</span><span class="p">(</span><span class="s">''</span><span class="p">,</span> <span class="n">colors</span><span class="o">=</span><span class="p">[</span><span class="s">"#e41a1c"</span><span class="p">,</span> <span class="s">"#4daf4a"</span><span class="p">])</span>
<span class="n">im3</span> <span class="o">=</span> <span class="n">ax</span><span class="p">[</span><span class="mi">3</span><span class="p">].</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_emb</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">X_emb</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">s</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
                    <span class="n">c</span><span class="o">=</span><span class="n">aggregated_predicts</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">cmap</span><span class="o">=</span><span class="n">cmap3</span><span class="p">,</span> <span class="n">norm</span><span class="o">=</span><span class="n">colors</span><span class="p">.</span><span class="n">LogNorm</span><span class="p">())</span>
<span class="n">im3</span><span class="p">.</span><span class="n">set_clim</span><span class="p">(</span><span class="mf">0.73</span><span class="p">,</span> <span class="mf">1.02</span><span class="p">)</span>
<span class="n">cbar3</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">(</span><span class="n">im3</span><span class="p">,</span> <span class="n">ax</span><span class="o">=</span><span class="n">ax</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">ticks</span><span class="o">=</span><span class="p">[</span><span class="mf">0.75</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.85</span><span class="p">,</span> <span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">cbar3</span><span class="p">.</span><span class="n">ax</span><span class="p">.</span><span class="n">set_yticklabels</span><span class="p">([</span><span class="sa">r</span><span class="s">'$\leq$0.75'</span><span class="p">,</span> <span class="s">'0.80'</span><span class="p">,</span> <span class="s">'0.85'</span><span class="p">,</span> <span class="s">'0.90'</span><span class="p">,</span> <span class="s">'0.95'</span><span class="p">,</span> <span class="s">'1.00'</span><span class="p">])</span>
<span class="n">ax</span><span class="p">[</span><span class="mi">3</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="s">'Maximum of "predict_proba"'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>

<span class="k">for</span> <span class="n">axs</span> <span class="ow">in</span> <span class="n">ax</span><span class="p">:</span>
    <span class="n">clean_axes</span><span class="p">(</span><span class="n">axs</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/metakmeans/output_26_0.png" /></center></div></p>

<p><div align="justify">The histogram of the maximum of our &quot;<code>.predict_proba</code>&quot; shows that, for a reasonable number of examples, the individual clusterings disagree slightly. This disagreement gives a sense of the uncertainty and robustness of each cluster assignment, which is the central idea of <a href="https://en.wikipedia.org/wiki/Fuzzy_clustering">soft clustering</a> algorithms. For most examples, however, the individual <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a> fully agree.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mf">2.5</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">hist</span><span class="p">(</span><span class="n">aggregated_predicts</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">bins</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">25</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_yscale</span><span class="p">(</span><span class="s">"log"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">'Maximum of "predict_proba" per instance'</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"Frequency (log scale)"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Histogram of assigned cluster certainty"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/metakmeans/output_28_0.png" /></center></div></p>
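
<p><div align="justify">The aggregation behind this histogram can be shown with a minimal sketch on made-up labels. It makes two assumptions:</div></p>

<ul>
<li>Each individual K-Means's labels have already been mapped to meta-clusters. This is the job of <code>meta_clusters_map_</code> in the class at the end of the post.</li>
<li><code>np.eye</code> indexing stands in for the one-hot encoding.</li>
</ul>

<p><div align="justify">Averaging the one-hot vectors over the estimators gives a soft assignment. Its row-wise maximum is the certainty plotted above.</div></p>

```python
import numpy as np

# Toy (made-up) meta-cluster labels: rows are 4 individual K-Means,
# columns are 3 instances, with 3 meta-clusters in total.
mapped_labels = np.array([
    [0, 0, 0],
    [0, 1, 2],
    [0, 1, 1],
    [0, 1, 1],
])
n_clusters = 3

# One-hot encode every label: shape (n_estimators, n_samples, n_clusters)
one_hot = np.eye(n_clusters)[mapped_labels]

# Average over the estimators to get a soft assignment per instance
proba = one_hot.mean(axis=0)

# Row-wise maximum: 1.0 means every individual K-Means agrees
certainty = proba.max(axis=1)
# certainty holds 1.0, 0.75 and 0.5 for the three instances
```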

<p><div align="justify">This view lets us identify the examples that are hardest to cluster, giving a notion of <a href="https://deslib.readthedocs.io/en/latest/modules/util/instance_hardness.html">instance hardness</a> for our clustering problem. In our example, hardness seems to be associated with digits that resemble an 8.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">aggregated_predicts</span><span class="p">)[(</span><span class="n">aggregated_predicts</span><span class="o">&lt;</span><span class="mf">0.45</span><span class="p">).</span><span class="nb">all</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
 <span class="p">.</span><span class="n">rename</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="n">dict_cluster</span><span class="p">).</span><span class="n">T</span><span class="p">.</span><span class="n">sort_index</span><span class="p">().</span><span class="n">T</span><span class="p">)</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>0</th>
      <th>1</th>
      <th>2</th>
      <th>3</th>
      <th>4</th>
      <th>5</th>
      <th>6</th>
      <th>7</th>
      <th>8</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>630</th>
      <td>0.0</td>
      <td>0.084</td>
      <td>0.06</td>
      <td>0.408</td>
      <td>0.0</td>
      <td>0.000</td>
      <td>0.000</td>
      <td>0.000</td>
      <td>0.448</td>
    </tr>
    <tr>
      <th>1385</th>
      <td>0.0</td>
      <td>0.204</td>
      <td>0.00</td>
      <td>0.164</td>
      <td>0.0</td>
      <td>0.196</td>
      <td>0.000</td>
      <td>0.424</td>
      <td>0.012</td>
    </tr>
    <tr>
      <th>1386</th>
      <td>0.0</td>
      <td>0.088</td>
      <td>0.00</td>
      <td>0.060</td>
      <td>0.0</td>
      <td>0.228</td>
      <td>0.000</td>
      <td>0.312</td>
      <td>0.312</td>
    </tr>
    <tr>
      <th>1533</th>
      <td>0.0</td>
      <td>0.076</td>
      <td>0.00</td>
      <td>0.388</td>
      <td>0.0</td>
      <td>0.000</td>
      <td>0.196</td>
      <td>0.000</td>
      <td>0.340</td>
    </tr>
    <tr>
      <th>1616</th>
      <td>0.0</td>
      <td>0.032</td>
      <td>0.00</td>
      <td>0.420</td>
      <td>0.0</td>
      <td>0.000</td>
      <td>0.308</td>
      <td>0.000</td>
      <td>0.240</td>
    </tr>
  </tbody>
</table>
</div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">ncols</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mf">2.5</span><span class="p">))</span>

<span class="n">plt</span><span class="p">.</span><span class="n">gray</span><span class="p">()</span>
<span class="k">for</span> <span class="n">axs</span><span class="p">,</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">aggregated_predicts</span><span class="p">)[(</span><span class="n">aggregated_predicts</span><span class="o">&lt;</span><span class="mf">0.45</span><span class="p">).</span><span class="nb">all</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)].</span><span class="n">index</span><span class="p">):</span>
    <span class="n">axs</span><span class="p">.</span><span class="n">matshow</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="n">reshape</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span><span class="mi">8</span><span class="p">))</span>
    <span class="n">axs</span><span class="p">.</span><span class="n">set_xticks</span><span class="p">([])</span>
    <span class="n">axs</span><span class="p">.</span><span class="n">set_yticks</span><span class="p">([])</span>
    <span class="n">axs</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s"> - Target: </span><span class="si">{</span><span class="n">digits</span><span class="p">.</span><span class="n">target</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">7</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>
<p><div align="justify"><center><img src="/assets/img/metakmeans/output_31_0.png" /></center></div></p>
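
<p><div align="justify">The filter above uses a fixed <code>0.45</code> threshold. To rank the hardest instances instead, one simple option is to score each instance by one minus its maximum soft-assignment probability. The sketch below uses that score.</div></p>

<ul>
<li>The function name <code>hardness_ranking</code> and the toy <code>proba</code> matrix are hypothetical.</li>
<li>The score is not one of the neighborhood-based measures documented by DESlib.</li>
</ul>

```python
import numpy as np

def hardness_ranking(proba, top=5):
    """Return the indices and scores of the `top` hardest instances,
    using 1 - max(proba) as a (hypothetical) hardness score."""
    hardness = 1.0 - proba.max(axis=1)
    # Sort by decreasing hardness; "stable" keeps ties in their original order
    order = np.argsort(-hardness, kind="stable")[:top]
    return order, hardness[order]

# Toy soft assignments for 4 instances and 3 clusters
proba = np.array([
    [1.00, 0.00, 0.00],
    [0.25, 0.75, 0.00],
    [0.25, 0.50, 0.25],
    [0.40, 0.40, 0.20],
])
idx, scores = hardness_ranking(proba, top=2)
print(idx, scores)
```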

<hr />

<h2 id="considerações-finais">Final remarks</h2>

<p><div align="justify">This idea of clustering centroids is not new; it can even be used to choose the K-Means initialization. That algorithm is called Refined K-Means [<a href="#bibliography">1</a>]. However, it does not seem to offer a clear advantage over <a href="https://en.wikipedia.org/wiki/K-means%2B%2B">K-Means++</a> with multiple initializations, which is what <a href="https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html"><code>sklearn.cluster.KMeans</code></a> offers through its <code>init="k-means++"</code> and <code>n_init</code> parameters.</div></p>
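
<p><div align="justify">For comparison, a centroid-clustering initialization in the spirit of Refined K-Means can be sketched with scikit-learn. This is a simplified version under my own assumptions, not the exact procedure of [<a href="#bibliography">1</a>]:</div></p>

<ul>
<li>Subsamples are half the size of the data and are drawn without replacement.</li>
<li>A meta K-Means is fitted on the stacked centroids.</li>
<li>The meta-centers are passed to the final K-Means as a fixed <code>init</code> array.</li>
</ul>

<p><div align="justify">The function name <code>refined_init_kmeans</code> is hypothetical. The baseline is the usual K-Means++ with <code>n_init=10</code> restarts.</div></p>

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits

def refined_init_kmeans(X, n_clusters, n_subsamples=10, random_state=0):
    """Hypothetical sketch: initialize K-Means with the clustered centroids
    of several K-Means fitted on random subsamples."""
    rng = np.random.RandomState(random_state)
    centers = []
    for i in range(n_subsamples):
        # Half-size subsample drawn without replacement (an assumption)
        idx = rng.choice(X.shape[0], X.shape[0] // 2, replace=False)
        km = KMeans(n_clusters=n_clusters, n_init=1, random_state=random_state + i).fit(X[idx])
        centers.append(km.cluster_centers_)
    # Cluster the stacked centroids to obtain one "meta-centroid" per cluster
    meta = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state).fit(np.vstack(centers))
    # The initialization is fixed, so a single run is enough
    return KMeans(n_clusters=n_clusters, init=meta.cluster_centers_, n_init=1).fit(X)

X = load_digits().data
refined = refined_init_kmeans(X, n_clusters=10)
baseline = KMeans(n_clusters=10, init="k-means++", n_init=10, random_state=0).fit(X)
print(refined.inertia_, baseline.inertia_)
```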

<p><div align="justify">There are clearly applications where this soft clustering view is worth trying. In the experiments I ran for this discussion, though, the individually fitted clusterings rarely disagreed much: about 70% of the examples have <code>aggregated_predicts.max(axis=1)</code> equal to 1, as computed below. The hard clusters obtained at the end of our soft clustering strategy (by taking the <code>.argmax</code>) are also very similar to those found by a regular K-Means, with a Rand index of about 0.98 between the two labelings. I therefore don't think this is an especially promising technique. Given the small additional effort, though, it is worth testing whenever you are considering a K-Means.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">unique_km_labels</span> <span class="o">=</span> <span class="n">KMeans</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">).</span><span class="n">labels_</span>

<span class="p">(</span><span class="n">rand_score</span><span class="p">(</span><span class="n">unique_km_labels</span><span class="p">,</span> <span class="n">aggregated_predicts</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)),</span>
 <span class="p">(</span><span class="n">aggregated_predicts</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">==</span><span class="mi">1</span><span class="p">).</span><span class="n">mean</span><span class="p">())</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(0.9799745280650514, 0.6951144094001237)
</code></pre></div></div>

<p><div align="justify">Finally, it is easy to generalize these ideas to other centroid-based clustering algorithms, such as <a href="https://en.wikipedia.org/wiki/K-medians_clustering">K-Medians</a> or <a href="https://en.wikipedia.org/wiki/K-medoids">K-Medoids</a>. This means we are not necessarily tied to the <a href="https://vitaliset.github.io/distancia/">Euclidean distance</a>, which is the <a href="https://stats.stackexchange.com/questions/81481/why-does-k-means-clustering-algorithm-use-only-euclidean-distance-metric">distance used by K-Means</a>.</div></p>
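
<p><div align="justify">Here is a minimal K-Medians sketch in NumPy that illustrates swapping the distance. Each point is assigned to the nearest center under the Manhattan (L1) distance, and each center is updated to the coordinate-wise median of its points.</div></p>

<ul>
<li>The function <code>k_medians</code> is hypothetical; it is not a scikit-learn estimator.</li>
<li>Its <code>centers</code> could play the role of <code>cluster_centers_</code> in the meta-clustering strategy. In that case it would arguably make sense for the meta step to use the same L1 geometry.</li>
</ul>

```python
import numpy as np

def _l1_assign(X, centers):
    # Manhattan distance from every point to every center: (n_samples, n_clusters)
    dist = np.abs(X[:, None, :] - centers[None, :, :]).sum(axis=2)
    return dist.argmin(axis=1)

def k_medians(X, n_clusters, n_iter=100, random_state=0):
    """Minimal K-Medians: L1 assignment and coordinate-wise median update."""
    X = np.asarray(X, dtype=float)
    rng = np.random.RandomState(random_state)
    # Initialize the centers with distinct random data points
    centers = X[rng.choice(X.shape[0], n_clusters, replace=False)]
    for _ in range(n_iter):
        labels = _l1_assign(X, centers)
        new_centers = centers.copy()
        for k in range(n_clusters):
            members = X[labels == k]
            if members.size:  # keep the previous center if a cluster is empty
                new_centers[k] = np.median(members, axis=0)
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    # Return labels consistent with the final centers
    return centers, _l1_assign(X, centers)

X = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]])
centers, labels = k_medians(X, n_clusters=2)
print(centers, labels)
```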

<hr />

<h2 id="implementação-grosseira-da-classe-do-estimador">A rough implementation of the estimator class</h2>

<p><div align="justify">If you want to use these ideas, something along the lines of the class implemented below should work. It is compatible with libraries that follow the <a href="https://scikit-learn.org/stable/developers/develop.html">scikit-learn estimator conventions</a>. Watch out for the case <code>n_clusters=2</code>, where <code>sklearn.preprocessing.LabelBinarizer</code> produces a single column instead of two. In that case <code>.predict_proba</code> returns only one column, and <code>.predict</code> always returns 0 because it takes the <code>.argmax</code> over that single column.</div></p>
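
<p><div align="justify">The <code>n_clusters=2</code> caveat is easy to check. The snippet below shows the single-column output of <code>LabelBinarizer</code> for two classes. It also shows one possible workaround, a hypothetical change that is not part of the class: indexing an identity matrix with <code>np.eye(n_clusters)[labels]</code>, which always yields one column per cluster.</div></p>

```python
import numpy as np
from sklearn.preprocessing import LabelBinarizer

labels = np.array([0, 1, 1, 0])

# With exactly two classes, LabelBinarizer returns a single column
binary = LabelBinarizer().fit([0, 1]).transform(labels)
print(binary.shape)  # (4, 1)

# Identity-matrix indexing gives one column per cluster, even for n_clusters=2
one_hot = np.eye(2)[labels]
print(one_hot.shape)  # (4, 2)
```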

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">sklearn.base</span> <span class="kn">import</span> <span class="n">BaseEstimator</span>
<span class="kn">from</span> <span class="nn">sklearn.cluster</span> <span class="kn">import</span> <span class="n">KMeans</span>
<span class="kn">from</span> <span class="nn">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">LabelBinarizer</span>

<span class="k">class</span> <span class="nc">MetaKMeans</span><span class="p">(</span><span class="n">BaseEstimator</span><span class="p">):</span>
    <span class="s">"""Meta K-Means clustering.

    A Meta K-Means is a meta estimator that fits several K-Means
    on various sub-samples of the dataset and uses averaging to
    measure uncertainty related to predicted clusters.

    Parameters
    ----------
    n_clusters : int, default=8
        The number of clusters to form as well as the number of
        metacentroids to generate.

    n_estimators : int, default=100
        The number of K-Means in the ensemble.

    random_state : int, default=42
        Controls both the randomness of the bootstrapping of the samples used
        when building the individual K-Means and the randomness of the
        choice of initial centroids of each K-Means.

    KMeans_params : dict, default={}
        Explicitly set some of the base K-Means parameters as **KMeans_params.
    """</span>
    
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_clusters</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">n_estimators</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">,</span> <span class="n">KMeans_params</span><span class="o">=</span><span class="p">{}):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_clusters</span> <span class="o">=</span> <span class="n">n_clusters</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_estimators</span> <span class="o">=</span> <span class="n">n_estimators</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">random_state</span> <span class="o">=</span> <span class="n">random_state</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">KMeans_params</span> <span class="o">=</span> <span class="n">KMeans_params</span>

    <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">estimators_</span> <span class="o">=</span> \
        <span class="p">[</span><span class="n">KMeans</span><span class="p">(</span><span class="n">n_clusters</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">n_clusters</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">i</span><span class="o">+</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span><span class="p">,</span> <span class="o">**</span><span class="bp">self</span><span class="p">.</span><span class="n">KMeans_params</span><span class="p">)</span>
         <span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">i</span><span class="o">+</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span><span class="p">).</span><span class="n">choice</span><span class="p">(</span><span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])])</span>
         <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n_estimators</span><span class="p">)]</span>
        
        <span class="n">cluster_centers</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">vstack</span><span class="p">([</span><span class="n">km</span><span class="p">.</span><span class="n">cluster_centers_</span> <span class="k">for</span> <span class="n">km</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">estimators_</span><span class="p">])</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">meta_kmeans_</span> <span class="o">=</span> <span class="n">KMeans</span><span class="p">(</span><span class="n">n_clusters</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">n_clusters</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">random_state</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">cluster_centers</span><span class="p">)</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">metacluster_centers_</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">meta_kmeans_</span><span class="p">.</span><span class="n">cluster_centers_</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">meta_clusters_map_</span> <span class="o">=</span> \
        <span class="p">[{</span><span class="n">j</span><span class="p">:</span> <span class="bp">self</span><span class="p">.</span><span class="n">meta_kmeans_</span><span class="p">.</span><span class="n">labels_</span><span class="p">[</span><span class="bp">self</span><span class="p">.</span><span class="n">n_clusters</span><span class="o">*</span><span class="n">i</span><span class="o">+</span><span class="n">j</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n_clusters</span><span class="p">)}</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n_estimators</span><span class="p">)]</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">lb_</span> <span class="o">=</span> <span class="n">LabelBinarizer</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n_clusters</span><span class="p">)))</span>
        
        <span class="k">return</span> <span class="bp">self</span>
    
    <span class="k">def</span> <span class="nf">predict_proba</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="k">return</span> \
        <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="bp">self</span><span class="p">.</span><span class="n">lb_</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="n">map_dic</span><span class="p">.</span><span class="n">get</span><span class="p">,</span> <span class="n">km</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X</span><span class="p">)))))</span>
                  <span class="k">for</span> <span class="n">km</span><span class="p">,</span> <span class="n">map_dic</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">estimators_</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">meta_clusters_map_</span><span class="p">)]).</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    
    <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X</span><span class="p">).</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">class_meta_kmeans_with_params</span> <span class="o">=</span> \
<span class="n">MetaKMeans</span><span class="p">(</span><span class="n">n_clusters</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span> <span class="n">n_estimators</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">KMeans_params</span><span class="o">=</span><span class="p">{</span><span class="s">"init"</span><span class="p">:</span> <span class="s">"random"</span><span class="p">}).</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>

<span class="n">class_meta_kmeans</span> <span class="o">=</span> \
<span class="n">MetaKMeans</span><span class="p">(</span><span class="n">n_clusters</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span> <span class="n">n_estimators</span><span class="o">=</span><span class="mi">250</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
<span class="n">class_predict_probas</span> <span class="o">=</span> <span class="n">class_meta_kmeans</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>

<span class="c1"># Since I'm using the same random_state, I expect the class results
# to match the ones computed above.
</span><span class="p">((</span><span class="n">class_predict_probas</span> <span class="o">==</span> <span class="n">aggregated_predicts</span><span class="p">).</span><span class="nb">all</span><span class="p">(),</span>
 <span class="p">(</span><span class="n">class_meta_kmeans</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X</span><span class="p">)</span> <span class="o">==</span> <span class="n">aggregated_predicts</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)).</span><span class="nb">all</span><span class="p">())</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(True, True)
</code></pre></div></div>

<h2 id="referências"><a name="bibliography">References</a></h2>

<p><div align="justify">[1] <a href="https://www.sciencedirect.com/science/article/abs/pii/S1574013717300692">Cluster ensembles: A survey of approaches with recent extensions and applications. Tossapon Boongoen, Natthakan Iam-On. Computer Science Review, Volume 28, 2018.</a></div></p>

<hr />

<p><div align="justify">All files and the environment needed to reproduce the experiments can be found in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/metakmeans">repository for this post</a>.</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇧🇷&quot;, &quot;clustering&quot;]" /><summary type="html"><![CDATA[One possible way to aggregate the results of different K-Means runs into an ensemble.]]></summary></entry><entry><title type="html">Using Boruta critically</title><link href="https://vitaliset.github.io/boruta/" rel="alternate" type="text/html" title="Using Boruta critically" /><published>2022-09-05T00:00:00+00:00</published><updated>2022-09-05T00:00:00+00:00</updated><id>https://vitaliset.github.io/boruta</id><content type="html" xml:base="https://vitaliset.github.io/boruta/"><![CDATA[<p><div align="justify">For a fixed predictive power on the development set, a model with fewer features tends to be less prone to exploiting noise and spurious relationships in its training set, which can lead to better performance outside the lab. Careful feature selection is therefore an important <em>data-centric</em> tool when modeling supervised machine learning problems.</div></p>

<p><div align="justify"><i>$\oint$ To illustrate the previous claim, consider the <a href="https://youtu.be/Dc0sr0kdBVI">VC dimension</a> (a measure of the complexity of a hypothesis set). For a perceptron (a linear classifier), it is $d+1$, where $d$ is the number of features used by the model [<a href="#bibliography">1</a>]. A higher VC dimension means you need more data to guarantee that the performance measured on the training set is close to the true performance. In practice, for a fixed amount of data, the higher the VC dimension, the higher the risk of overfitting. Consequently, suppose two perceptrons have similar training performance but one uses more features than the other. The one with more features is more likely to overfit [<a href="#bibliography">1</a>].</i></div></p>
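<p><div align="justify">The sketch below makes the "more data for a higher VC dimension" argument concrete. It evaluates one standard form of the VC generalization bound: with probability at least $1-\delta$, $E_{out} \le E_{in} + \sqrt{\frac{8}{N}\ln\frac{4 m_{\mathcal{H}}(2N)}{\delta}}$. The growth function is bounded by $m_{\mathcal{H}}(N) \le N^{d_{VC}} + 1$. The function name <code>vc_penalty</code> and the choice $\delta = 0.05$ are illustrative assumptions. The bound is known to be loose, so only its trend in $d$ and $N$ is meaningful.</div></p>

```python
import math


def vc_penalty(n_samples, d_vc, delta=0.05):
    """Generalization-gap term of the VC bound,
    sqrt(8 / N * ln(4 * m_H(2N) / delta)), using m_H(N) <= N**d_vc + 1."""
    # Python integers are exact, so (2N)**d_vc cannot overflow before the log.
    growth = (2 * n_samples) ** d_vc + 1
    log_term = math.log(4) + math.log(growth) - math.log(delta)
    return math.sqrt(8 / n_samples * log_term)


# A perceptron with d features has d_vc = d + 1.
for d in (4, 20):
    for n in (1_000, 100_000):
        print(f"d={d}, N={n}: penalty={vc_penalty(n, d + 1):.3f}")
```

<p><div align="justify">At $N = 1000$, going from 4 to 20 features roughly doubles the penalty term. Increasing $N$ shrinks it again, which is the "more data" requirement in numbers.</div></p>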

<p><div align="justify">However, feature selection does not get the attention it deserves in most Machine Learning courses: few methods are presented, and only superficially. The few places that do discuss the topic usually focus on techniques that scale poorly as the number of features grows. They are therefore impractical for most industry applications (such as the greedy strategies of <a href="https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SequentialFeatureSelector.html"><code>sklearn.feature_selection.SequentialFeatureSelector</code></a>).</div></p>
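<p><div align="justify">A back-of-the-envelope count shows why greedy sequential selection scales poorly. Assume a forward search that, at each step, cross-validates every remaining candidate feature. This is roughly how <code>SequentialFeatureSelector</code> with <code>direction="forward"</code> proceeds, with one cross-validation per candidate. Selecting $k$ out of $p$ features then costs $\text{cv} \cdot \sum_{i=0}^{k-1} (p - i)$ model fits. The helper name <code>forward_selection_fits</code> is hypothetical, and early stopping through <code>tol</code> is ignored.</div></p>

```python
def forward_selection_fits(n_features, n_to_select, cv=5):
    """Model fits needed by a greedy forward search that, at every step,
    cross-validates each feature not yet selected (no early stopping)."""
    candidates = sum(n_features - step for step in range(n_to_select))
    return cv * candidates


print(forward_selection_fits(20, 4))     # 370 fits: 4 out of 20 features
print(forward_selection_fits(2000, 50))  # 493875 fits: 50 out of 2000 features
```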

<p><div align="justify">At <a href="https://br.linkedin.com/showcase/serasa-experian-datalab">Serasa Experian's DataLab</a>, feature selection is extremely relevant because of the nature of the problems we work on. In most cases, we have a few thousand features available in Serasa's data bureau, and it is not easy to identify beforehand which features will bring the largest gains. We need techniques that cope with this large number of features while still producing a selection that makes sense.</div></p>

<p><div align="justify">In this post, we will motivate the construction of Boruta [<a href="#bibliography">2</a>], one of the feature selection techniques most used by the scientists at <a href="https://br.linkedin.com/showcase/serasa-experian-datalab">DataLab</a>, and share some tips for using it in practice. We will also illustrate the <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> class. It comes from the <a href="https://github.com/scikit-learn-contrib/scikit-learn-contrib/blob/master/README.md">scikit-learn-contrib</a> ecosystem, so it is compatible with libraries that follow the <a href="https://scikit-learn.org/stable/developers/develop.html">scikit-learn coding standards</a>.</div></p>

<hr />

<p><div align="justify">To illustrate the feature selection problem, we will use <a href="https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html"><code>sklearn.datasets.make_classification</code></a> to create a generic classification problem. One of its parameters sets the number of features that are actually useful for prediction.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
<span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">make_classification</span>

<span class="n">N_FEATURES</span> <span class="o">=</span> <span class="mi">20</span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> \
<span class="n">make_classification</span><span class="p">(</span><span class="n">n_samples</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span>
                    <span class="n">n_features</span><span class="o">=</span><span class="n">N_FEATURES</span><span class="p">,</span>
                    <span class="n">n_informative</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
                    <span class="n">n_redundant</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
                    <span class="n">n_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
                    <span class="n">flip_y</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
                    <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
                    <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>

<span class="n">X</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="sa">f</span><span class="s">'column_</span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s">'</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">N_FEATURES</span><span class="p">)])</span>

<span class="n">X</span><span class="p">.</span><span class="n">head</span><span class="p">()</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>column_1</th>
      <th>column_2</th>
      <th>column_3</th>
      <th>...</th>
      <th>column_18</th>
      <th>column_19</th>
      <th>column_20</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>-1.050478</td>
      <td>-1.323568</td>
      <td>0.912474</td>
      <td>...</td>
      <td>1.238946</td>
      <td>0.209659</td>
      <td>-0.491636</td>
    </tr>
    <tr>
      <th>1</th>
      <td>-1.580834</td>
      <td>-2.747104</td>
      <td>1.777419</td>
      <td>...</td>
      <td>0.152355</td>
      <td>-0.822420</td>
      <td>1.121031</td>
    </tr>
    <tr>
      <th>2</th>
      <td>-0.885704</td>
      <td>-0.614600</td>
      <td>0.501004</td>
      <td>...</td>
      <td>0.193590</td>
      <td>0.850898</td>
      <td>-0.137372</td>
    </tr>
    <tr>
      <th>3</th>
      <td>-1.525438</td>
      <td>-2.967793</td>
      <td>1.884777</td>
      <td>...</td>
      <td>-0.316073</td>
      <td>0.615771</td>
      <td>1.203884</td>
    </tr>
    <tr>
      <th>4</th>
      <td>-1.076826</td>
      <td>-1.014619</td>
      <td>0.752233</td>
      <td>...</td>
      <td>0.300474</td>
      <td>0.622207</td>
      <td>-1.138833</td>
    </tr>
  </tbody>
</table>
<p>5 rows × 20 columns</p>
</div>

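<p><div align="justify">We ask for 2 informative features and 2 redundant features (linear combinations of the informative ones) and set <code>shuffle=False</code>. As a result, <code>make_classification</code> places them in the first columns, so the 4 useful features are <code>column_1</code>, <code>column_2</code>, <code>column_3</code>, and <code>column_4</code>. The remaining 16 columns are pure noise.</div></p>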
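<p><div align="justify">Here is a quick sanity check of that claim, assuming the default <code>shift=0.0</code> and <code>scale=1.0</code> of <code>make_classification</code>. The redundant columns should be exactly reproducible as a linear combination (without intercept) of the two informative ones, while a noise column should not. The variable names are only for illustration.</div></p>

```python
import numpy as np
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, n_features=20, n_informative=2,
                           n_redundant=2, n_classes=2, flip_y=0.1,
                           shuffle=False, random_state=42)

informative = X[:, :2]  # column_1, column_2
redundant = X[:, 2:4]   # column_3, column_4
noise = X[:, 4:]        # column_5 ... column_20

# Least squares of the redundant columns on the informative ones:
# the residual is only floating-point error.
coef, *_ = np.linalg.lstsq(informative, redundant, rcond=None)
print(np.abs(informative @ coef - redundant).max())

# The same fit for a noise column leaves a large residual.
coef_noise, *_ = np.linalg.lstsq(informative, noise[:, :1], rcond=None)
print(np.abs(informative @ coef_noise - noise[:, :1]).max())
```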

<h1 id="motivando-a-construção-do-boruta">Motivating the construction of Boruta</h1>

<h2 id="medindo-a-importância-de-uma-variável">Measuring the importance of a feature</h2>

<p><div align="justify">One of the most common ways to select features is to take advantage of models that, in some way, select them during training itself. Trees, and consequently tree ensembles, are perhaps the best example. They follow a <a href="https://www.edureka.co/community/46109/what-is-greedy-approach-in-decision-tree-algorithm">greedy strategy of making the best possible split at each step</a>, where "best" is defined by some criterion, usually the purity of the resulting nodes in the classification case. As a result, they keep choosing relevant features. Weakly discriminative features are used much less often than the features that actually help with prediction [<a href="#bibliography">3</a>].</div></p>
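<p><div align="justify">The greedy split search can be sketched in a few lines of NumPy. This is a simplified illustration, not scikit-learn's implementation. It assumes integer class labels, uses the Gini impurity, and tries the midpoints between consecutive distinct values as candidate thresholds. A feature that separates the classes well yields a much larger impurity decrease than a noise feature, so it is the one picked for the split.</div></p>

```python
import numpy as np


def gini(y):
    """Gini impurity 1 - sum_k p_k**2 of an array of integer class labels."""
    if len(y) == 0:
        return 0.0
    p = np.bincount(y) / len(y)
    return 1.0 - np.sum(p ** 2)


def best_split(x, y):
    """Threshold on one feature that maximizes the weighted Gini decrease."""
    parent = gini(y)
    best_thr, best_gain = None, 0.0
    values = np.unique(x)
    # Candidate thresholds: midpoints between consecutive distinct values.
    for thr in (values[:-1] + values[1:]) / 2:
        left, right = y[x <= thr], y[x > thr]
        child = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if parent - child > best_gain:
            best_thr, best_gain = thr, parent - child
    return best_thr, best_gain


def choose_feature(X, y):
    """Greedy step: the feature whose best split reduces impurity the most."""
    gains = [best_split(X[:, j], y)[1] for j in range(X.shape[1])]
    return int(np.argmax(gains)), gains


rng = np.random.default_rng(0)
signal = rng.normal(size=200)
y = (signal > 0).astype(int)
X = np.column_stack([rng.normal(size=200), signal])  # column 0 is pure noise
feature, gains = choose_feature(X, y)
print(feature, np.round(gains, 3))  # the informative column 1 is chosen
```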

<p><div align="justify">This process naturally yields feature importance measures. One is the number of times a feature is used in a split. This is the default for the <code>.feature_importances_</code> attribute of LightGBM ensembles, such as <a href="https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html"><code>lightgbm.LGBMClassifier</code></a>. Another is the total impurity decrease (information gain) produced by the splits on that feature, weighted by the number of samples reaching each split. This is the default for scikit-learn tree ensembles such as <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"><code>sklearn.ensemble.RandomForestClassifier</code></a>, <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html"><code>sklearn.ensemble.ExtraTreesClassifier</code></a>, and <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html"><code>sklearn.ensemble.GradientBoostingClassifier</code></a>. A gain-based importance is also what LightGBM returns when we set <code>importance_type=&#39;gain&#39;</code>.</div></p>
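<p><div align="justify">To see what the impurity-based measure actually adds up, the sketch below recomputes it from the arrays of a fitted scikit-learn <code>DecisionTreeClassifier</code> (<code>tree_.children_left</code>, <code>tree_.impurity</code>, <code>tree_.weighted_n_node_samples</code>). It then compares the result with <code>feature_importances_</code>. It also counts splits per feature, which is the idea behind LightGBM's default <code>'split'</code> importance. The helper names are hypothetical. For a forest, scikit-learn averages these per-tree values over the trees and renormalizes.</div></p>

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

LEAF = -1  # scikit-learn marks leaves with children_left == -1


def split_count_importance(clf):
    """Number of internal nodes that split on each feature."""
    t = clf.tree_
    internal = t.children_left != LEAF
    return np.bincount(t.feature[internal], minlength=clf.n_features_in_)


def impurity_importance(clf):
    """Weighted impurity decrease accumulated per feature, normalized to sum 1."""
    t = clf.tree_
    w = t.weighted_n_node_samples
    imp = np.zeros(clf.n_features_in_)
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == LEAF:
            continue  # leaves do not split, so they contribute nothing
        imp[t.feature[node]] += (w[node] * t.impurity[node]
                                 - w[left] * t.impurity[left]
                                 - w[right] * t.impurity[right])
    return imp / imp.sum()


X, y = make_classification(n_samples=500, n_features=6, n_informative=2,
                           n_redundant=0, shuffle=False, random_state=0)
clf = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X, y)

print(split_count_importance(clf))
print(np.round(impurity_importance(clf), 3))
print(np.round(clf.feature_importances_, 3))  # should match the line above
```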

<p><div align="justify">With any of these built-in importance measures, it is reasonable to rank our features from most to least important.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.ensemble</span> <span class="kn">import</span> <span class="n">RandomForestClassifier</span>

<span class="n">rfc</span> <span class="o">=</span> <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>

<span class="n">df_imp</span> <span class="o">=</span> \
<span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">X</span><span class="p">.</span><span class="n">columns</span><span class="p">,</span> <span class="n">rfc</span><span class="p">.</span><span class="n">feature_importances_</span><span class="p">)),</span>
              <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'feature_name'</span><span class="p">,</span> <span class="s">'feature_importance'</span><span class="p">])</span>
 <span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="n">by</span><span class="o">=</span><span class="s">'feature_importance'</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
 <span class="p">.</span><span class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="p">)</span>

<span class="n">df_imp</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>feature_name</th>
      <th>feature_importance</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>column_2</td>
      <td>0.278748</td>
    </tr>
    <tr>
      <th>1</th>
      <td>column_3</td>
      <td>0.201150</td>
    </tr>
    <tr>
      <th>2</th>
      <td>column_4</td>
      <td>0.092612</td>
    </tr>
    <tr>
      <th>3</th>
      <td>column_1</td>
      <td>0.085144</td>
    </tr>
    <tr>
      <th>...</th>
      <td>...</td>
      <td>...</td>
    </tr>
    <tr>
      <th>16</th>
      <td>column_5</td>
      <td>0.018714</td>
    </tr>
    <tr>
      <th>17</th>
      <td>column_16</td>
      <td>0.018641</td>
    </tr>
    <tr>
      <th>18</th>
      <td>column_18</td>
      <td>0.017565</td>
    </tr>
    <tr>
      <th>19</th>
      <td>column_20</td>
      <td>0.016912</td>
    </tr>
  </tbody>
</table>
<p>20 rows × 2 columns</p>
</div>

<p><div align="justify"><i>$\oint$ There are other ways to measure the importance of a feature, for example using its <a href="https://towardsdatascience.com/shap-explained-the-way-i-wish-someone-explained-it-to-me-ab81cc69ef30">SHAP value</a> contributions. <a href="https://github.com/slundberg/shap"><code>shap.Explainer(model).shap_values(X)</code></a> returns, for each sample, how much each feature contributed to the prediction. Averaging these contributions over all samples therefore shows how useful the feature was for discriminating the samples as a whole. We take the absolute value before averaging so that the values do not cancel out (imagine a feature that pushes the prediction up for some values and down for others). Note that the importance ranking given by SHAP can differ from the ranking given by the estimator's usual <code>.feature_importances_</code> attribute. This happens in our example among the least important features.</i></div></p>
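<p><div align="justify">Here is a toy example of the cancellation mentioned above. The SHAP matrix is made up for illustration: rows are samples, columns are features. The first feature moves predictions strongly in both directions, the second only slightly. Averaging the signed values makes the first feature look useless, while the mean absolute value correctly ranks it first.</div></p>

```python
import numpy as np

# Hypothetical SHAP values: 4 samples x 2 features.
shap_vals = np.array([
    [0.30, 0.05],
    [-0.30, 0.05],
    [0.20, -0.04],
    [-0.20, 0.04],
])

signed_mean = shap_vals.mean(axis=0)       # opposite contributions cancel out
mean_abs = np.abs(shap_vals).mean(axis=0)  # size of the contributions
print(signed_mean)  # first feature averages to 0
print(mean_abs)     # first feature now clearly dominates
```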

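<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">shap</span>

<span class="n">explainer</span> <span class="o">=</span> <span class="n">shap</span><span class="p">.</span><span class="n">TreeExplainer</span><span class="p">(</span><span class="n">rfc</span><span class="p">)</span>
<span class="c1"># In shap versions that return one array per class, shap_vals[0] holds the
# values for class 0; in a binary problem |SHAP| is the same for both classes.
</span><span class="n">shap_vals</span> <span class="o">=</span> <span class="n">explainer</span><span class="p">.</span><span class="n">shap_values</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>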

<span class="n">df_imp_shap</span> <span class="o">=</span> \
<span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">X</span><span class="p">.</span><span class="n">columns</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">shap_vals</span><span class="p">[</span><span class="mi">0</span><span class="p">]).</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">))),</span>
              <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'feature_name'</span><span class="p">,</span> <span class="s">'shap_importance'</span><span class="p">])</span>
 <span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="n">by</span><span class="o">=</span><span class="s">'shap_importance'</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
 <span class="p">.</span><span class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="p">)</span>

<span class="n">df_imp_shap</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>feature_name</th>
      <th>shap_importance</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>column_2</td>
      <td>0.197645</td>
    </tr>
    <tr>
      <th>1</th>
      <td>column_3</td>
      <td>0.107211</td>
    </tr>
    <tr>
      <th>2</th>
      <td>column_4</td>
      <td>0.043797</td>
    </tr>
    <tr>
      <th>3</th>
      <td>column_1</td>
      <td>0.041570</td>
    </tr>
    <tr>
      <th>...</th>
      <td>...</td>
      <td>...</td>
    </tr>
    <tr>
      <th>16</th>
      <td>column_18</td>
      <td>0.005851</td>
    </tr>
    <tr>
      <th>17</th>
      <td>column_16</td>
      <td>0.005268</td>
    </tr>
    <tr>
      <th>18</th>
      <td>column_5</td>
      <td>0.005099</td>
    </tr>
    <tr>
      <th>19</th>
      <td>column_20</td>
      <td>0.005019</td>
    </tr>
  </tbody>
</table>
<p>20 rows × 2 columns</p>
</div>

<p><div align="justify"><i>We have not discussed Boruta yet, but it relies on this ranking for its analysis. It is usually implemented with the estimator's own importance measure (the <code>.feature_importances_</code> attribute, or <code>.coef_</code> for linear algorithms). This discrepancy between the SHAP ranking and the built-in ranking motivated some contributors to implement <a href="https://github.com/Ekeany/Boruta-Shap">Boruta-Shap</a>. However, incorporating SHAP into the Boruta procedure does not seem trivial, and the library tends to be slow.</i></div></p>

<p><div align="justify"><i>One possible alternative is to manually adapt your classifier's <code>.feature_importances_</code> attribute, storing <code>X</code> at training time so it can be used to compute the SHAP values. Here is how I implement it:</i></div></p>

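<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">shap</span>
<span class="kn">from</span> <span class="nn">sklearn.ensemble</span> <span class="kn">import</span> <span class="n">RandomForestClassifier</span>
<span class="kn">from</span> <span class="nn">sklearn.utils.validation</span> <span class="kn">import</span> <span class="n">check_is_fitted</span>


<span class="k">class</span> <span class="nc">SHAPImportanceRandomForestClassifier</span><span class="p">(</span><span class="n">RandomForestClassifier</span><span class="p">):</span>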
    <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">X_</span> <span class="o">=</span> <span class="n">X</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="n">sample_weight</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span>
    <span class="o">@</span><span class="nb">property</span>
    <span class="k">def</span> <span class="nf">feature_importances_</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">check_is_fitted</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
        <span class="n">explainer</span> <span class="o">=</span> <span class="n">shap</span><span class="p">.</span><span class="n">TreeExplainer</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
        <span class="n">shap_vals</span> <span class="o">=</span> <span class="n">explainer</span><span class="p">.</span><span class="n">shap_values</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">X_</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">shap_vals</span><span class="p">[</span><span class="mi">0</span><span class="p">]).</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">shap_feature_importances_</span> <span class="kn">import</span> <span class="n">SHAPImportanceRandomForestClassifier</span>

<span class="n">rfc_shap</span> <span class="o">=</span> <span class="n">SHAPImportanceRandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="n">rfc_shap</span><span class="p">.</span><span class="n">feature_importances_</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>array([0.04156985, 0.19764501, 0.10721142, 0.04379691, 0.00509938,
       0.00967927, 0.00900892, 0.00769202, 0.01053711, 0.00973848,
       0.00764462, 0.00725161, 0.00690175, 0.00718789, 0.00600269,
       0.00526766, 0.00659648, 0.00585107, 0.00726538, 0.00501896])
</code></pre></div></div>

<p><div align="justify"><i>Note that this implementation uses the training set itself to compute the SHAP values. This is debatable: keep in mind that SHAP-based importances (mean absolute value) computed on a test set may differ from those computed on the training set. If you want that level of rigor, you may want to set aside part of your dataset to compute the SHAP values. I implement this idea in the <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/shap_feature_importances_.py"><code>XSHAPImportanceRandomForestClassifier</code></a> class, in the <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/shap_feature_importances_.py"><code>shap_feature_importances_.py</code></a> file in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/boruta">repository for this post</a>. Still, to sleep soundly, remember that the usual <code>.feature_importances_</code> of tree-based algorithms is computed on the training set. Computing SHAP on the training set is therefore not such a great blasphemy.</i></div></p>

<h2 id="selecionando-as-k-melhores-variáveis">Selecting the <code class="language-plaintext highlighter-rouge">K</code> “best features”</h2>

<p><div align="justify">If we want our model to use only the <code>K</code> most useful features, the natural way to choose them is to take the <code>K</code> features with the highest importance values.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">K</span> <span class="o">=</span> <span class="mi">4</span>

<span class="p">(</span><span class="n">df_imp</span>
 <span class="p">.</span><span class="n">head</span><span class="p">(</span><span class="n">K</span><span class="p">)</span>
 <span class="p">.</span><span class="n">feature_name</span>
 <span class="p">.</span><span class="n">to_list</span><span class="p">()</span>
<span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>['column_2', 'column_3', 'column_4', 'column_1']
</code></pre></div></div>

<p><div align="justify">This is one of the most common feature selection strategies in industry, but it raises a few questions. The first and most immediate one is how to choose the ideal number of features <code>K</code>. In this illustrative case we know that 4 is the correct number, but in most real applications it is unrealistic to know it beforehand.</div></p>

<p><div align="justify"><i>$\oint$ A widely used strategy, which we will not dwell on, treats <code>K</code> as a hyperparameter to optimize. We grow the model's feature list following the ranking given by a model trained on all features. In the example below, we do this with <a href="https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html"><code>sklearn.model_selection.GridSearchCV</code></a>. To do so, we build a <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/selectktop_selector.py"><code>SelectKTop</code></a> class that follows the interface scikit-learn requires of feature selectors, that is, the one defined by <a href="https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectorMixin.html"><code>sklearn.feature_selection.SelectorMixin</code></a>. You can see the implementation of this class in the <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/selectktop_selector.py"><code>selectktop_selector.py</code></a> file in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/boruta">repository for this post</a>.</i></div></p>

<p><div align="justify"><i>PS: The <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/selectktop_selector.py"><code>SelectKTop</code></a> class is roughly equivalent to the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectFromModel.html"><code>sklearn.feature_selection.SelectFromModel</code></a> class, which I only discovered after I had finished writing this post!</i></div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">selectktop_selector</span> <span class="kn">import</span> <span class="n">SelectKTop</span>

<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">GridSearchCV</span><span class="p">,</span> <span class="n">RepeatedStratifiedKFold</span>
<span class="kn">from</span> <span class="nn">sklearn.pipeline</span> <span class="kn">import</span> <span class="n">make_pipeline</span>

<span class="n">grid</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">GridSearchCV</span><span class="p">(</span>
        <span class="n">make_pipeline</span><span class="p">(</span><span class="n">SelectKTop</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">),</span>
                      <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)),</span>
        <span class="n">param_grid</span><span class="o">=</span><span class="p">{</span><span class="s">'selectktop__K'</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="n">N_FEATURES</span><span class="o">+</span><span class="mi">1</span><span class="p">)},</span>
        <span class="n">cv</span><span class="o">=</span><span class="n">RepeatedStratifiedKFold</span><span class="p">(</span><span class="n">n_splits</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">n_repeats</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">),</span>
        <span class="n">scoring</span><span class="o">=</span><span class="s">'roc_auc'</span><span class="p">)</span>
    <span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>

<span class="n">df_cv</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">grid</span><span class="p">.</span><span class="n">cv_results_</span><span class="p">)[[</span>
        <span class="s">'param_selectktop__K'</span><span class="p">,</span>
        <span class="s">'mean_test_score'</span><span class="p">,</span>
        <span class="s">'std_test_score'</span>
    <span class="p">]])</span>

<span class="n">cv_best</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">df_cv</span>
    <span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="n">by</span><span class="o">=</span><span class="s">'mean_test_score'</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
    <span class="p">.</span><span class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
    <span class="p">.</span><span class="n">loc</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plt</span><span class="p">.</span><span class="n">errorbar</span><span class="p">(</span><span class="n">df_cv</span><span class="p">.</span><span class="n">param_selectktop__K</span><span class="p">,</span> <span class="n">df_cv</span><span class="p">.</span><span class="n">mean_test_score</span><span class="p">,</span> <span class="mf">1.96</span><span class="o">*</span><span class="n">df_cv</span><span class="p">.</span><span class="n">std_test_score</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">cv_best</span><span class="p">.</span><span class="n">param_selectktop__K</span><span class="p">,</span> <span class="n">cv_best</span><span class="p">.</span><span class="n">mean_test_score</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylim</span><span class="p">(</span><span class="mf">0.75</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'K of SelectKTop'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xticks</span><span class="p">(</span><span class="n">df_cv</span><span class="p">.</span><span class="n">param_selectktop__K</span><span class="p">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'Performance (ROCAUC)'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/boruta/output_15_0.png" /></center></div></p>

<p><div align="justify"><i>In our controlled experiment, the search keeps slightly more variables than the correct set (one extra, <code>column_10</code>), and it keeps all the useful ones.</i></div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">grid</span><span class="p">.</span><span class="n">best_estimator_</span><span class="p">.</span><span class="n">steps</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">].</span><span class="n">get_feature_names_out</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>array(['column_1', 'column_2', 'column_3', 'column_4', 'column_10'],
      dtype=object)
</code></pre></div></div>
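<p><div align="justify">The PS above points to <code>SelectFromModel</code>, so here is a minimal sketch of the same <code>K</code> search done with it instead of <code>SelectKTop</code>. It uses a small hypothetical <code>make_classification</code> dataset, not the post's <code>X</code> and <code>y</code>. Setting <code>threshold=-np.inf</code> makes <code>max_features</code> the only selection criterion, so the selector keeps exactly the top <code>max_features</code> columns by importance. The dataset size and forest sizes are arbitrary choices that keep the run fast.</div></p>

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.pipeline import make_pipeline

# Small synthetic stand-in for the post's dataset (hypothetical sizes).
X_demo, y_demo = make_classification(
    n_samples=500, n_features=10, n_informative=4, n_redundant=0, random_state=0
)

pipe = make_pipeline(
    # threshold=-np.inf: keep exactly the top `max_features` columns by importance.
    SelectFromModel(RandomForestClassifier(n_estimators=50, random_state=0), threshold=-np.inf),
    RandomForestClassifier(n_estimators=50, random_state=0),
)

# K is treated as a hyperparameter, as with SelectKTop in the post.
grid_sfm = GridSearchCV(
    pipe,
    param_grid={"selectfrommodel__max_features": list(range(1, X_demo.shape[1] + 1))},
    cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=0),
    scoring="roc_auc",
).fit(X_demo, y_demo)

best_k = grid_sfm.best_params_["selectfrommodel__max_features"]
selected = grid_sfm.best_estimator_[0].get_support(indices=True)  # column indices kept
print(best_k, selected)
```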

<p><div align="justify"><i>This method can be made more robust by varying the <code>random_state</code> of the <code>base_estimator</code>. That gives a distribution of importances for each variable instead of a single value, which is naturally noisier. Some of the scientists at <a href="https://br.linkedin.com/showcase/serasa-experian-datalab">DataLab</a> often combine this technique with SHAP to measure importance, for example by passing the <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/shap_feature_importances_.py"><code>SHAPImportanceRandomForestClassifier</code></a> as the <code>base_estimator</code> of <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/selectktop_selector.py"><code>SelectKTop</code></a>. They use it as an alternative to Boruta, which, as we will see, tends to be very slow.</i></div></p>
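<p><div align="justify">A minimal sketch of that idea follows. It assumes plain impurity importances from <code>RandomForestClassifier</code> rather than SHAP, and it uses a hypothetical <code>make_classification</code> dataset. The sketch refits the forest with several seeds, summarizes each feature's importance by its mean and standard deviation across fits, and ranks by the mean. The <code>top_k</code> cut at 3 is an arbitrary illustration.</div></p>

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X_demo, y_demo = make_classification(
    n_samples=500, n_features=8, n_informative=3, n_redundant=0, random_state=0
)
columns = [f"column_{i + 1}" for i in range(X_demo.shape[1])]

# One row of importances per random_state: a distribution per feature, not a single value.
importances = pd.DataFrame(
    [
        RandomForestClassifier(n_estimators=100, random_state=seed)
        .fit(X_demo, y_demo)
        .feature_importances_
        for seed in range(10)
    ],
    columns=columns,
)

# Rank by mean importance; the std shows how noisy each single-fit estimate is.
summary = importances.agg(["mean", "std"]).T.sort_values("mean", ascending=False)
top_k = summary.index[:3].to_list()
print(summary)
```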

<h2 id="selecionando-as-k-melhores-variáveis-com-ponto-de-corte-sugerido-por-uma-variável-aleatória">Selecting the K best variables with a cutoff suggested by a random variable</h2>

<p><div align="justify">A noise variable is one we know is useless for prediction. Creating one gives us a cutoff for filtering the variables that actually help the prediction. The approach is to measure the importance of the random variable and keep only the variables that turn out to be more important than it.</div></p>

<p><div align="justify">Adding a new column, for example one sampled independently from a $\mathcal{N}(0,1)$ distribution, gives us a new feature-importance ranking.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">normal_noise_X</span> <span class="o">=</span> <span class="p">(</span><span class="n">X</span><span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">noise_column</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="mi">42</span><span class="p">).</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])))</span>
<span class="n">normal_noise_X</span><span class="p">[</span><span class="n">normal_noise_X</span><span class="p">.</span><span class="n">columns</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]].</span><span class="n">head</span><span class="p">()</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>noise_column</th>
      <th>column_20</th>
      <th>column_19</th>
      <th>...</th>
      <th>column_3</th>
      <th>column_2</th>
      <th>column_1</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>0.496714</td>
      <td>-0.491636</td>
      <td>0.209659</td>
      <td>...</td>
      <td>0.912474</td>
      <td>-1.323568</td>
      <td>-1.050478</td>
    </tr>
    <tr>
      <th>1</th>
      <td>-0.138264</td>
      <td>1.121031</td>
      <td>-0.822420</td>
      <td>...</td>
      <td>1.777419</td>
      <td>-2.747104</td>
      <td>-1.580834</td>
    </tr>
    <tr>
      <th>2</th>
      <td>0.647689</td>
      <td>-0.137372</td>
      <td>0.850898</td>
      <td>...</td>
      <td>0.501004</td>
      <td>-0.614600</td>
      <td>-0.885704</td>
    </tr>
    <tr>
      <th>3</th>
      <td>1.523030</td>
      <td>1.203884</td>
      <td>0.615771</td>
      <td>...</td>
      <td>1.884777</td>
      <td>-2.967793</td>
      <td>-1.525438</td>
    </tr>
    <tr>
      <th>4</th>
      <td>-0.234153</td>
      <td>-1.138833</td>
      <td>0.622207</td>
      <td>...</td>
      <td>0.752233</td>
      <td>-1.014619</td>
      <td>-1.076826</td>
    </tr>
  </tbody>
</table>
<p>5 rows × 21 columns</p>
</div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">normal_noise_rfc</span> <span class="o">=</span> <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">normal_noise_X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>

<span class="n">df_imp_normal_noise</span> <span class="o">=</span> \
<span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">normal_noise_X</span><span class="p">.</span><span class="n">columns</span><span class="p">,</span> <span class="n">normal_noise_rfc</span><span class="p">.</span><span class="n">feature_importances_</span><span class="p">)),</span>
              <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'feature_name'</span><span class="p">,</span> <span class="s">'feature_importance'</span><span class="p">])</span>
 <span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="n">by</span><span class="o">=</span><span class="s">'feature_importance'</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="p">)</span>

<span class="n">df_imp_normal_noise</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>feature_name</th>
      <th>feature_importance</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>1</th>
      <td>column_2</td>
      <td>0.266446</td>
    </tr>
    <tr>
      <th>2</th>
      <td>column_3</td>
      <td>0.205667</td>
    </tr>
    <tr>
      <th>3</th>
      <td>column_4</td>
      <td>0.087548</td>
    </tr>
    <tr>
      <th>0</th>
      <td>column_1</td>
      <td>0.084593</td>
    </tr>
    <tr>
      <th>...</th>
      <td>...</td>
      <td>...</td>
    </tr>
    <tr>
      <th>8</th>
      <td>column_9</td>
      <td>0.019112</td>
    </tr>
    <tr>
      <th>4</th>
      <td>column_5</td>
      <td>0.018706</td>
    </tr>
    <tr>
      <th>18</th>
      <td>column_19</td>
      <td>0.018264</td>
    </tr>
    <tr>
      <th>19</th>
      <td>column_20</td>
      <td>0.017692</td>
    </tr>
  </tbody>
</table>
<p>21 rows × 2 columns</p>
</div>

<p><div align="justify">The last column of the training matrix is our known noise column. This technique keeps only the variables whose importance is greater than the threshold set by the importance of that unrelated variable.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">normal_noise_importance</span> <span class="o">=</span> \
<span class="n">normal_noise_rfc</span><span class="p">.</span><span class="n">feature_importances_</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>

<span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
 <span class="n">df_imp_normal_noise</span>
 <span class="p">.</span><span class="n">query</span><span class="p">(</span><span class="sa">f</span><span class="s">"feature_importance &gt; </span><span class="si">{</span><span class="n">normal_noise_importance</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
 <span class="p">.</span><span class="n">feature_name</span>
<span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>array(['column_2', 'column_3', 'column_4', 'column_1', 'column_6',
       'column_10', 'column_14'], dtype=object)
</code></pre></div></div>

<p><div align="justify">Note that choosing $\mathcal{N}(0,1)$ for the noise variable was completely arbitrary. The choice matters, though, and can lead to a different feature selection. In our controlled example, switching the noise to $\textrm{Exp}(1)$ makes us select different final variables purely by chance. The code below also refits the forest with a different <code>random_state</code>, which contributes to the difference as well.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">exp_noise_X</span> <span class="o">=</span> \
<span class="p">(</span><span class="n">X</span><span class="p">.</span><span class="n">assign</span><span class="p">(</span><span class="n">noise_column</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="mi">42</span><span class="p">).</span><span class="n">exponential</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])))</span>
<span class="n">exp_noise_rfc</span> <span class="o">=</span> \
<span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">exp_noise_X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="n">exp_noise_importance</span> <span class="o">=</span> \
<span class="n">exp_noise_rfc</span><span class="p">.</span><span class="n">feature_importances_</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>

<span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span>
 <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">exp_noise_X</span><span class="p">.</span><span class="n">columns</span><span class="p">,</span> <span class="n">exp_noise_rfc</span><span class="p">.</span><span class="n">feature_importances_</span><span class="p">)),</span>
              <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'feature_name'</span><span class="p">,</span> <span class="s">'feature_importance'</span><span class="p">])</span>
 <span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="n">by</span><span class="o">=</span><span class="s">'feature_importance'</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
 <span class="p">.</span><span class="n">query</span><span class="p">(</span><span class="sa">f</span><span class="s">"feature_importance &gt; </span><span class="si">{</span><span class="n">exp_noise_importance</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
 <span class="p">.</span><span class="n">feature_name</span>
<span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>array(['column_2', 'column_3', 'column_4', 'column_1', 'column_14',
       'column_6', 'column_10', 'column_9', 'column_12', 'column_13',
       'column_7', 'column_18'], dtype=object)
</code></pre></div></div>

<p><div align="justify">This exposes a problem with the method. It is powerful because it lets us select variables without choosing <code>K</code> arbitrarily. However, the choice of the noise variable's distribution is a relevant source of variation.</div></p>
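<p><div align="justify">One way to see this variability is to repeat the noise-threshold rule for several seeds and for both $\mathcal{N}(0,1)$ and $\textrm{Exp}(1)$ noise, then count how often each feature beats the noise column. This sketch runs on a hypothetical <code>make_classification</code> dataset. Features selected in only some of the runs are the ones whose selection depends on luck. The number of seeds and trees are arbitrary choices.</div></p>

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X_demo, y_demo = make_classification(
    n_samples=500, n_features=10, n_informative=3, n_redundant=0, random_state=0
)
X_df = pd.DataFrame(X_demo, columns=[f"column_{i + 1}" for i in range(10)])

counts = pd.Series(0, index=X_df.columns)
n_runs = 0
for seed in range(5):
    rng = np.random.RandomState(seed)
    # Two noise distributions per seed: N(0,1) and Exp(1).
    for noise in (rng.normal(size=len(X_df)), rng.exponential(size=len(X_df))):
        X_noise = X_df.assign(noise_column=noise)
        forest = RandomForestClassifier(n_estimators=100, random_state=seed).fit(X_noise, y_demo)
        importances = pd.Series(forest.feature_importances_, index=X_noise.columns)
        # Same rule as above: keep features more important than the noise column.
        counts += importances.drop("noise_column") > importances["noise_column"]
        n_runs += 1

selection_frequency = (counts / n_runs).sort_values(ascending=False)
print(selection_frequency)
```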

<p><div align="justify">In many cases, whether a variable is discrete or continuous can influence the importance measure. Trees are one example: a continuous noise variable offers more candidate splits, so it is more likely to be chosen. The scale of the added feature can also distort the measurement, for example if we use the coefficients of a <a href="https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html"><code>sklearn.linear_model.Lasso</code></a> as importances.</div></p>
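<p><div align="justify">Here is a quick check of the first point, using impurity-based importances of a <code>RandomForestClassifier</code> on a hypothetical <code>make_classification</code> dataset. We add two noise columns, one continuous and one binary, both independent of the target. The continuous one offers many more candidate split points, so it typically receives a noticeably larger <code>.feature_importances_</code> even though neither column is informative.</div></p>

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X_demo, y_demo = make_classification(
    n_samples=1000, n_features=5, n_informative=3, n_redundant=0, random_state=0
)
rng = np.random.RandomState(0)
X_df = pd.DataFrame(X_demo, columns=[f"column_{i + 1}" for i in range(5)]).assign(
    noise_continuous=rng.normal(size=len(X_demo)),     # many distinct values -> many splits
    noise_binary=rng.randint(0, 2, size=len(X_demo)),  # only one possible split
)

forest = RandomForestClassifier(n_estimators=200, random_state=0).fit(X_df, y_demo)
importances = pd.Series(forest.feature_importances_, index=X_df.columns)
print(importances[["noise_continuous", "noise_binary"]])
```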

<p><div align="justify">All this variability can sometimes cause a bad feature to be selected while a good one is discarded by bad luck.</div></p>

<p><div align="justify">Boruta tries to address both issues at once. It keeps the marginal distributions of the noise features equal to those of the original features, and it tries to be robust to variability by repeating the experiment several times.</div></p>

<h1 id="ideias-gerais-do-boruta">General ideas behind Boruta</h1>

<p><div align="justify">Many useful texts already explain Boruta didactically and with examples. This post does not aim to duplicate that literature. Its goal is to compile the core ideas for practical use, so we will only mention the main aspects and invite you to read other references on the topic in detail, such as the post <a href="https://towardsdatascience.com/boruta-explained-the-way-i-wish-someone-explained-it-to-me-4489d70e154a">Boruta Explained Exactly How You Wished Someone Explained to You</a>. The construction we built earlier makes Boruta's ideas even clearer and justifies its design.</div></p>

<p><div align="justify">In summary, Boruta [<a href="#bibliography">2,4</a>]:</div></p>
<ul>
  <li>
    <p><div align="justify">Creates variables with no relationship to the <em>target</em> by shuffling, across rows, variables already present in the dataset (these are what we call <em>shadow</em> variables).</div></p>
  </li>
  <li>
    <p><div align="justify">Handles variability by repeating the process several times and counting how many times each variable of interest beats the <code>perc</code> percentile of the <em>shadow features</em>' <code>.feature_importances_</code>. By default <code>perc=100</code>, so the comparison is against the maximum of the <em>shadow features</em>' <code>.feature_importances_</code>. In other words, if any <em>shadow</em> is better, the variable of interest does not score a hit in that round.</div></p>
  </li>
  <li>
    <p><div align="justify">Finally, a hypothesis test assesses whether we can claim, at a significance level <code>alpha</code>, that the feature of interest is better than the <code>perc</code> percentile of the <em>shadow features</em>' importances.</div></p>
  </li>
  <li>
    <p><div align="justify">The hypothesis test splits the feature set into three categories:</div></p>
    <ul>
      <li>
        <p><div align="justify">Variables that are statistically better than the <em>shadow features</em> (these make up <code>.support_</code>);</div></p>
      </li>
      <li>
        <p><div align="justify">Variables that are statistically worse than the <em>shadow</em> variables (rejected, so we exclude them);</div></p>
      </li>
      <li>
        <p><div align="justify">Variables that cannot be claimed, with statistical significance, to be either better or worse than the <em>shadow</em> variables (<code>.support_weak_</code>).</div></p>
      </li>
    </ul>
  </li>
  <li>
    <p><div align="justify">In practice, as soon as it is confident that a given variable is not important, it removes that variable from the subsequent iterations.</div></p>
  </li>
</ul>
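<p><div align="justify">To make the steps above concrete, here is a simplified sketch of the hit-counting loop. It assumes a <code>RandomForestClassifier</code> as the importance estimator. It is not BorutaPy's implementation: it runs a fixed number of rounds, never drops features early, and applies no multiple-testing correction (BorutaPy corrects the p-values). Treat it as an illustration only. The function name <code>boruta_sketch</code> is hypothetical, and its <code>perc</code> and <code>alpha</code> arguments only mimic the roles described above.</div></p>

```python
import numpy as np
from scipy.stats import binom
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def boruta_sketch(X, y, n_rounds=20, perc=100, alpha=0.05, random_state=0):
    """Simplified Boruta-style hit counting (hypothetical helper, not BorutaPy)."""
    rng = np.random.RandomState(random_state)
    n_features = X.shape[1]
    hits = np.zeros(n_features, dtype=int)
    for _ in range(n_rounds):
        # Shadows: every column shuffled independently across rows, which keeps
        # each marginal distribution but breaks any link with y.
        shadows = np.column_stack([rng.permutation(X[:, j]) for j in range(n_features)])
        forest = RandomForestClassifier(
            n_estimators=50, max_depth=5, random_state=rng.randint(1_000_000)
        )
        forest.fit(np.hstack([X, shadows]), y)
        imp = forest.feature_importances_
        # A hit: the real feature beats the `perc` percentile of the shadow importances.
        bar = np.percentile(imp[n_features:], perc)
        hits += imp[:n_features] > bar
    # One-sided binomial tests against p=0.5, without multiple-testing correction.
    p_accept = binom.sf(hits - 1, n_rounds, 0.5)
    p_reject = binom.cdf(hits, n_rounds, 0.5)
    confirmed = p_accept <= alpha
    rejected = p_reject <= alpha
    tentative = ~(confirmed | rejected)
    return hits, confirmed, tentative, rejected


# shuffle=False puts the 3 informative columns first and pure noise last.
X_demo, y_demo = make_classification(
    n_samples=500, n_features=10, n_informative=3, n_redundant=0,
    shuffle=False, random_state=0,
)
hits, confirmed, tentative, rejected = boruta_sketch(X_demo, y_demo)
print("hits:", hits)
print("confirmed columns:", np.flatnonzero(confirmed))
```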

<h1 id="o-borutaborutapy">The <a href="https://github.com/scikit-learn-contrib/boruta_py">boruta.BorutaPy</a></h1>

<p><div align="justify">First, we need to instantiate a base estimator, passed through the <code>estimator</code> parameter, that <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> uses internally to compute feature importances through <code>.feature_importances_</code>. Any hyperparameters relevant to the problem can be added to it, such as <code>class_weight</code> if the problem is highly imbalanced.</div></p>

<p><div align="justify">With a tree ensemble, keep in mind that deeper trees change <code>.feature_importances_</code> but take longer to train. Shallower trees are a reasonable choice, since the largest gains usually come from the first splits.</div></p>

<p><div align="justify"><a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> accepts any estimator that exposes the <code>.feature_importances_</code> attribute after <code>.fit(X, y)</code> [<a href="#bibliography">5</a>]. You can use this to your advantage by picking the estimator best suited to your problem, including more efficient tree-based algorithms such as <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html"><code>sklearn.ensemble.ExtraTreesClassifier</code></a>. Keep in mind that the way Extremely Randomized Trees are built also affects their <code>.feature_importances_</code>, which can change the final feature selection.</div></p>

<p><div align="justify">To illustrate the library in practice, I will use the <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/shap_feature_importances_.py"><code>SHAPImportanceRandomForestClassifier</code></a> we created earlier. It is essentially a <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"><code>sklearn.ensemble.RandomForestClassifier</code></a> with SHAP values in place of the usual <code>.feature_importances_</code>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">boruta</span> <span class="kn">import</span> <span class="n">BorutaPy</span>

<span class="n">boruta_forest</span> <span class="o">=</span> <span class="n">SHAPImportanceRandomForestClassifier</span><span class="p">(</span><span class="n">max_depth</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">One point that the documentation does not make very clear is that the <code>n_estimators</code> parameter of <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> overrides the estimator's own <code>n_estimators</code>, as the <a href="https://github.com/scikit-learn-contrib/boruta_py/blob/3cf4de864e83ad0c50e0cfa177b2bc2aa4735256/boruta/boruta_py.py#L268">BorutaPy source code</a> shows:</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># set n_estimators
</span><span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">n_estimators</span> <span class="o">!=</span> <span class="s">'auto'</span><span class="p">:</span>
    <span class="bp">self</span><span class="p">.</span><span class="n">estimator</span><span class="p">.</span><span class="n">set_params</span><span class="p">(</span><span class="n">n_estimators</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">n_estimators</span><span class="p">)</span>
</code></pre></div></div>
<p><div align="justify">The default is <code>n_estimators=1000</code>. If <code>n_estimators=&#39;auto&#39;</code>, <a href="https://github.com/scikit-learn-contrib/boruta_py/blob/3cf4de864e83ad0c50e0cfa177b2bc2aa4735256/boruta/boruta_py.py#L371">a rule based on the number of features being evaluated chooses the number of trees in the ensemble</a>.</div></p>
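<p><div align="justify">The sketch below illustrates both points with a plain <code>RandomForestClassifier</code>, without calling BorutaPy itself. <code>set_params</code> simply replaces whatever <code>n_estimators</code> you configured. <code>auto_n_estimators</code> is a hypothetical helper that transcribes the <code>'auto'</code> rule as it appears in the linked revision: <code>int(100 * 2n / (sqrt(2n) * depth))</code>, with <code>depth</code> defaulting to 10 when <code>max_depth=None</code>. Other versions of the library may differ, so check the source you actually run.</div></p>

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

forest = RandomForestClassifier(n_estimators=300, max_depth=7, random_state=42)
# What BorutaPy does when n_estimators != 'auto': the user's 300 trees are replaced.
forest.set_params(n_estimators=50)


def auto_n_estimators(n_feat, max_depth):
    """Hypothetical transcription of BorutaPy's 'auto' rule (check your version)."""
    depth = 10 if max_depth is None else max_depth
    # n_feat * 2 because the training matrix also holds one shadow per feature.
    multi = (n_feat * 2) / (np.sqrt(n_feat * 2) * depth)
    return int(multi * 100)  # 100 = how often each feature should be considered


print(forest.get_params()["n_estimators"], auto_n_estimators(20, 7))
```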

<p><div align="justify">Finally, <code>alpha</code> and <code>perc</code> are the other important parameters of <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> to watch:</div></p>
<ul>
  <li>
    <p><div align="justify"><code>perc</code> is the percentile of the <em>shadow features</em>' <code>.feature_importances_</code> used to decide whether each variable did well in a given round. It is an <code>int</code> from 0 to 100, and the closer it is to 100, the stricter we are when evaluating our features. By chance, some <em>shadow features</em> can get large importances and make the cutoff overly strict. That hurts, because we end up excluding marginal variables that are relevant but whose importance is not very large. The default is 100, but I recommend lowering it slightly (to 90, for example) when a problem has many variables, since more variables make it more likely that some <em>shadow feature</em> gets a high importance.</div></p>
  </li>
  <li>
    <p><div align="justify"><code>alpha</code> is a float from 0 to 1 that determines how the feature set is partitioned (<code>.support_weak_</code>, <code>.support_</code> and excluded). It sets how much certainty we require before claiming that a given feature is or is not relevant to the classification (or regression) problem. The default is 0.05, and I don't usually change it: since the two parameters are related, I prefer to keep <code>alpha</code> fixed and vary <code>perc</code>.</div></p>
  </li>
</ul>
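<p><div align="justify">Here is a small numeric illustration of how the two parameters act, using made-up shadow importances drawn at random (they are not outputs of any model). <code>perc</code> sets the per-round bar through <code>np.percentile</code>, and <code>alpha</code> sets how many hits a feature needs before it is confirmed. The hypothetical helper <code>min_hits_to_confirm</code> uses a one-sided binomial test with success probability 0.5 and no multiple-testing correction. BorutaPy corrects the p-values, so its real bar is stricter.</div></p>

```python
import numpy as np
from scipy.stats import binom

rng = np.random.RandomState(0)
shadow_importances = rng.exponential(scale=0.01, size=200)  # illustrative values only

bar_100 = np.percentile(shadow_importances, 100)  # perc=100: the maximum shadow
bar_90 = np.percentile(shadow_importances, 90)    # perc=90: a lower, less strict bar


def min_hits_to_confirm(n_rounds, alpha):
    """Smallest hit count whose one-sided binomial p-value (p=0.5) is <= alpha."""
    for hits in range(n_rounds + 1):
        if binom.sf(hits - 1, n_rounds, 0.5) <= alpha:
            return hits
    return None  # confirmation impossible with this few rounds


print(bar_100, bar_90, min_hits_to_confirm(20, 0.05), min_hits_to_confirm(20, 0.01))
```

<p><div align="justify">With 20 rounds, a feature needs 15 hits at <code>alpha=0.05</code> and 16 at <code>alpha=0.01</code>. With only 4 rounds no hit count is enough, which is one reason <code>max_iter</code> cannot be too small.</div></p>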

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">boruta</span> <span class="o">=</span> \
<span class="p">(</span><span class="n">BorutaPy</span><span class="p">(</span>
    <span class="n">estimator</span><span class="o">=</span><span class="n">boruta_forest</span><span class="p">,</span>
    <span class="n">n_estimators</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span>
    <span class="n">max_iter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="c1"># number of trials to perform
</span>    <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>
 <span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">y</span><span class="p">))</span> <span class="c1"># fit accepts np.array, not pd.DataFrame
</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">Finally, the selected features are easy to retrieve through the <code>.support_</code> and <code>.support_weak_</code> attributes.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">green_area</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">columns</span><span class="p">[</span><span class="n">boruta</span><span class="p">.</span><span class="n">support_</span><span class="p">].</span><span class="n">to_list</span><span class="p">()</span>
<span class="n">blue_area</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">columns</span><span class="p">[</span><span class="n">boruta</span><span class="p">.</span><span class="n">support_weak_</span><span class="p">].</span><span class="n">to_list</span><span class="p">()</span>

<span class="k">print</span><span class="p">(</span><span class="s">'Support columns:'</span><span class="p">,</span> <span class="n">green_area</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">'Weak support columns:'</span><span class="p">,</span> <span class="n">blue_area</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Support columns: ['column_1', 'column_2', 'column_3', 'column_4', 'column_10']
Weak support columns: ['column_9']
</code></pre></div></div>
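<p><div align="justify">Since <code>.support_</code> (confirmed) and <code>.support_weak_</code> (tentative) are boolean masks aligned with <code>X.columns</code>, they can also be combined, for instance to keep the tentative features as well. A small sketch with hypothetical masks, built by hand to reproduce the printed output above (no <code>boruta.BorutaPy</code> run involved):</div></p>

```python
import numpy as np
import pandas as pd

# Hypothetical column names and masks, shaped like BorutaPy's outputs:
# support_ marks confirmed features, support_weak_ marks tentative ones.
columns = pd.Index([f"column_{i}" for i in range(1, 13)])
support = np.zeros(len(columns), dtype=bool)
support[[0, 1, 2, 3, 9]] = True  # column_1..column_4 and column_10
support_weak = np.zeros(len(columns), dtype=bool)
support_weak[8] = True  # column_9

confirmed = columns[support].to_list()
# Confirmed plus tentative features, kept in the original column order.
confirmed_or_tentative = columns[support | support_weak].to_list()
print(confirmed_or_tentative)
```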

<h1 id="trade-off-de-qualidade-da-seleção-vs-tempo-quando-damos-um-undersample">Trade-off between “selection quality” and “time” when undersampling</h1>

<p><div align="justify">When the dataset is very large, <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> can take a long time to run, since it creates as many <em>shadow</em> features as there are features in the original set, doubling the data each forest is fit on. In many practical applications, you need to run <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> on a subset of the training set.</div></p>

<p><div align="justify">Here we run an experiment on a synthetic <code>make_classification</code> dataset with <code>n_samples=5000</code>, <code>n_features=100</code>, <code>n_informative=40</code> and <code>n_redundant=10</code> to see how the features chosen by <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> change as we vary the <code>frac</code> parameter of a <a href="https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.sample.html"><code>.sample</code></a> taken from the development set.</div></p>
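<p><div align="justify">The <code>boruta_sample_experiment</code> module used below is not shown in the post. As a rough sketch of how the data for such an experiment could be built (an assumption, not the author's helper), here is a <code>make_classification</code> call with <code>shuffle=False</code>, which keeps the informative and redundant features in the first columns, followed by a <code>.sample(frac=...)</code> subsample. Setting <code>n_repeated=0</code> is also an assumption (scikit-learn's default is 2), chosen so that exactly 50 columns carry signal; the <code>subsample</code> function and the seeds are illustrative.</div></p>

```python
import pandas as pd
from sklearn.datasets import make_classification

# Assumed setup: shuffle=False keeps the 40 informative and 10 redundant
# features in the first 50 columns, so the ground truth is known.
X_arr, y_arr = make_classification(
    n_samples=5000, n_features=100, n_informative=40, n_redundant=10,
    n_repeated=0, shuffle=False, random_state=42,
)
X = pd.DataFrame(X_arr, columns=[f"column_{i}" for i in range(X_arr.shape[1])])
y = pd.Series(y_arr, name="target")

useful = X.columns[:50]  # informative + redundant features


def subsample(X, y, frac, random_state=0):
    """Draw a fraction of the rows, keeping X and y aligned by index."""
    X_sub = X.sample(frac=frac, random_state=random_state)
    return X_sub, y.loc[X_sub.index]


X_sub, y_sub = subsample(X, y, frac=0.2)
print(X_sub.shape)  # (1000, 100)
```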

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">boruta_sample_experiment</span> <span class="kn">import</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">plot_heatmap</span><span class="p">,</span> <span class="n">plot_percentage_time</span>

<span class="n">dic_sample</span><span class="p">,</span> <span class="n">matrix</span><span class="p">,</span> <span class="n">X_big</span><span class="p">,</span> <span class="n">y_big</span> <span class="o">=</span> \
<span class="n">experiment</span><span class="p">(</span><span class="n">fracs</span><span class="o">=</span><span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.9</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>100%|██████████| 11/11 [14:08&lt;00:00, 77.18s/it]
</code></pre></div></div>

<p><div align="justify">Since the number of informative features plus the number of redundant ones is 50, half of our features are useful in this controlled example. In the plot below, for different values of <code>frac</code> (the fraction of the dataset used to train <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a>), we see which features get selected. Ideally, <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> should identify the first 50 features (x-axis) as the useful ones and select them (painting them green), while excluding the other 50 (painting them blue), since they are noise. Varying <code>frac</code> (y-axis) shows how it behaves.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plot_heatmap</span><span class="p">(</span><span class="n">dic_sample</span><span class="p">,</span> <span class="n">matrix</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/boruta/output_38_0.png" /></center></div></p>

<p><div align="justify">The first figure below summarizes the previous plot: as <code>frac</code> varies (x-axis), we track the percentage of useful features (green) and of useless features (orange) that are selected. The plot beside it shows the training time of <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> and the performance of a model trained with the features selected at each value of <code>frac</code>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plot_percentage_time</span><span class="p">(</span><span class="n">dic_sample</span><span class="p">,</span> <span class="n">matrix</span><span class="p">,</span> <span class="n">X_big</span><span class="p">,</span> <span class="n">y_big</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/boruta/output_40_0.png" /></center></div></p>

<p><div align="justify">As we can see, we do not need every sample to train our <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a>. In the example above, even though the sample has 5000 rows, around 3000 examples were already enough to find all 50 useful features for the problem.</div></p>
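<p><div align="justify">The summary in the first panel boils down to two percentages per value of <code>frac</code>. A minimal sketch of that computation, assuming the first <code>n_useful</code> columns are the useful ones (as in this controlled example) and using a made-up selection mask rather than an actual Boruta result:</div></p>

```python
import numpy as np


def selection_rates(selected_mask, n_useful):
    """Percentage of useful and of useless features selected by a mask.

    Assumes, as in the experiment, that the first `n_useful` columns are the
    useful ones and the remaining columns are noise.
    """
    selected_mask = np.asarray(selected_mask, dtype=bool)
    useful_pct = 100 * selected_mask[:n_useful].mean()
    useless_pct = 100 * selected_mask[n_useful:].mean()
    return float(useful_pct), float(useless_pct)


# Hypothetical selection over 100 features: 45 of the 50 useful ones
# and 2 of the 50 noise features were picked.
mask = np.zeros(100, dtype=bool)
mask[:45] = True
mask[[60, 75]] = True
print(selection_rates(mask, n_useful=50))  # (90.0, 4.0)
```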

<p><div align="justify">In my experience with <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a>, I feel comfortable with <em>a sample that has 15 times more examples than features (that is, <code>n_samples&gt;=15*n_features</code>)</em>. At that level I usually get good feature-selection results, and the algorithm runs in a time acceptable for development with a controlled <code>max_depth</code>. As a numeric example: if at <a href="https://br.linkedin.com/showcase/serasa-experian-datalab">DataLab</a> I am working on a problem with 5 thousand features, I am comfortable running <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> on a sample of 75 thousand rows, even when the development set has many more examples.</div></p>
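<p><div align="justify">The rule of thumb above is easy to encode. A minimal sketch, where <code>boruta_subsample_size</code> and <code>subsample_for_boruta</code> are hypothetical helpers (not part of <code>boruta.BorutaPy</code>) that cap the sample at the rows actually available and draw it with <code>DataFrame.sample</code>:</div></p>

```python
import numpy as np
import pandas as pd


def boruta_subsample_size(n_rows, n_features, ratio=15):
    """Rows to use under the n_samples >= ratio * n_features rule of thumb,
    capped at the number of rows actually available."""
    return min(n_rows, ratio * n_features)


def subsample_for_boruta(X, y, ratio=15, random_state=0):
    """Draw the rule-of-thumb number of rows, keeping X and y aligned."""
    n = boruta_subsample_size(len(X), X.shape[1], ratio)
    X_sub = X.sample(n=n, random_state=random_state)
    return X_sub, y.loc[X_sub.index]


# The numeric example from the text: 5 thousand features -> 75 thousand rows.
print(boruta_subsample_size(n_rows=1_000_000, n_features=5_000))  # 75000

# Small usage example on random, hypothetical data.
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(2_000, 20)))
y = pd.Series(rng.integers(0, 2, size=2_000))
X_sub, y_sub = subsample_for_boruta(X, y)
print(X_sub.shape)  # (300, 20)
```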

<p><div align="justify">On the other hand, the previous example shows that subsampling is not always the best choice, even in terms of time. In practice, <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> stops before <code>max_iter</code> once every feature has been confirmed or rejected (at significance level <code>alpha</code>), since it decides features along the way. In the experiment above, having more examples actually made <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> reach that certainty faster. In practice, though, this is rarely the case: more rows usually means a longer run.</div></p>
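<p><div align="justify">To see why the run can end early, it helps to look at the kind of test Boruta performs. A simplified sketch, assuming the usual Boruta setup in which each iteration records a hit when a feature's importance beats the best shadow feature (with <code>perc=100</code>; a lower <code>perc</code> uses a percentile of the shadow importances instead), and the number of hits is compared with a Binomial(<code>n_iter</code>, 0.5) distribution. The real <code>boruta.BorutaPy</code> also corrects for multiple testing, which is omitted here.</div></p>

```python
from math import comb


def binom_sf(k, n, p=0.5):
    """Probability of at least k successes in Binomial(n, p)."""
    return sum(comb(n, i) * p**i * (1 - p) ** (n - i) for i in range(k, n + 1))


def binom_cdf(k, n, p=0.5):
    """Probability of at most k successes in Binomial(n, p)."""
    return sum(comb(n, i) * p**i * (1 - p) ** (n - i) for i in range(0, k + 1))


def boruta_decision(hits, n_iter, alpha=0.05):
    """Simplified Boruta-style decision for a single feature.

    `hits` counts the iterations in which the feature beat the best shadow
    feature; under the null hypothesis each hit has probability 0.5.
    No multiple-testing correction is applied in this sketch.
    """
    if alpha > binom_sf(hits, n_iter):
        return "confirmed"   # too many hits to be luck
    if alpha > binom_cdf(hits, n_iter):
        return "rejected"    # too few hits to be luck
    return "tentative"       # not enough evidence yet


for hits in (0, 5, 10, 15, 20):
    print(hits, boruta_decision(hits, n_iter=20))

# With only 4 iterations, even 4 hits out of 4 is not enough evidence
# (p = 1/16 = 0.0625), so the feature stays tentative.
print(boruta_decision(4, n_iter=4))
```

<p><div align="justify">Once no feature is left as tentative, there is nothing left to test, which is why the loop can stop before <code>max_iter</code>; the last call also shows that a handful of iterations cannot confirm anything at <code>alpha=0.05</code>.</div></p>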

<h1 id="usando-o-boruta-na-prática-e-algumas-alternativas">Using Boruta in practice and some alternatives</h1>

<p><div align="justify">The ideas behind <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> are very interesting, but the final algorithm is computationally expensive. Fortunately, we can reuse the ideas behind its construction to build clever variations, which are worth trying if an initial run (with <code>max_depth ~ 10</code>, <code>perc=90</code> and <code>n_estimators=500</code>) is taking too long:</div></p>
<ol>
  <li>
    <p><div align="justify">Use <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/selectktop_selector.py"><code>SelectKTop</code></a> with a more robust <code>.feature_importances_</code> metric (such as SHAP, using something like our <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/shap_feature_importances_.py"><code>SHAPImportanceRandomForestClassifier</code></a>), taking care when choosing <code>K</code>;</div></p>
  </li>
  <li>
    <p><div align="justify">Adapt the <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/selectktop_selector.py"><code>SelectKTop</code></a> we built into an even more robust version that handles a distribution of <code>.feature_importances_</code> instead of a single estimator (by the way, this is a great exercise for readers who want to understand the <a href="https://scikit-learn.org/stable/developers/develop.html">scikit-learn API</a> better);</div></p>
  </li>
  <li>
    <p><div align="justify">Adapt <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/selectktop_selector.py"><code>SelectKTop</code></a> into a &quot;<code>SelectAboveNoise</code>&quot;, explained earlier, that creates the random features with <a href="https://numpy.org/doc/stable/reference/random/index.html"><code>numpy.random</code></a> (another very good exercise);</div></p>
  </li>
  <li>
    <p><div align="justify">Use <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> with faster algorithms (such as <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html"><code>sklearn.ensemble.ExtraTreesClassifier</code></a>), keeping in mind that their (even more randomized) training will affect <code>.feature_importances_</code> and, consequently, the final result;</div></p>
  </li>
  <li>
    <p><div align="justify">Reduce the sample used to train <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a>, respecting the <em>rule of thumb</em> <code>n_samples&gt;=15*n_features</code>;</div></p>
  </li>
  <li>
    <p><div align="justify">Change the algorithm more structurally so that it creates fewer <em>shadow</em> features in problems with many features (<em>to be tested</em>).</div></p>
  </li>
</ol>
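<p><div align="justify">The &quot;<code>SelectAboveNoise</code>&quot; idea from the list can be sketched in a few lines. This is a hypothetical, minimal version written as a plain function rather than a scikit-learn selector: it appends noise columns drawn from <code>numpy.random</code>, fits a single <code>RandomForestClassifier</code>, and keeps the original features whose importance beats the most important noise column. Unlike Boruta, there are no repeated iterations and no statistical test.</div></p>

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def select_above_noise(X, y, n_noise=5, random_state=0):
    """Indices of the features whose importance beats every noise column.

    Hypothetical sketch: one forest fit on X plus `n_noise` standard-normal
    columns; the largest noise importance is used as the cutoff.
    """
    rng = np.random.default_rng(random_state)
    noise = rng.normal(size=(X.shape[0], n_noise))
    X_aug = np.hstack([X, noise])
    forest = RandomForestClassifier(n_estimators=200, random_state=random_state)
    forest.fit(X_aug, y)
    importances = forest.feature_importances_
    threshold = importances[X.shape[1]:].max()  # best noise column
    return np.flatnonzero(importances[: X.shape[1]] > threshold)


# With shuffle=False the informative features are columns 0, 1 and 2.
X, y = make_classification(n_samples=1000, n_features=10, n_informative=3,
                           n_redundant=0, n_repeated=0, shuffle=False,
                           random_state=0)
selected = select_above_noise(X, y)
print(selected)
```

<p><div align="justify">Turning it into a proper selector (a class implementing <code>fit</code> and <code>_get_support_mask</code> on top of <code>SelectorMixin</code>) is the exercise suggested in the list.</div></p>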

<p><div align="justify">If your problem is reasonably small, using <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> with SHAP and tuning the hyperparameters of <a href="https://github.com/scikit-learn-contrib/boruta_py"><code>boruta.BorutaPy</code></a> is a good option. For that, the <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/boruta_selector.py"><code>Boruta</code></a> class I wrote in the file <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/boruta_selector.py"><code>boruta_selector.py</code></a> in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/boruta">repository for this post</a> will be useful. It already follows scikit-learn's <a href="https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectorMixin.html"><code>Selector</code></a> format and can be used the same way we saw <a href="https://github.com/vitaliset/vitaliset.github.io/blob/master/code/boruta/selectktop_selector.py"><code>SelectKTop</code></a> being used (with a pipeline and any scikit-learn <a href="https://github.com/scikit-learn/scikit-learn/blob/36958fb240fbe435673a9e3c52e769f01f36bec0/sklearn/model_selection/_search.py#L372"><code>BaseSearchCV</code></a>).</div></p>
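<p><div align="justify">The post's <code>Boruta</code> selector lives in the repository and is not reproduced here. As a stand-in with the same <code>SelectorMixin</code> interface, the sketch below uses scikit-learn's <code>SelectFromModel</code> to show how a selector step is tuned inside a <code>Pipeline</code> with <code>GridSearchCV</code>; the dataset and grid values are arbitrary.</div></p>

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# SelectFromModel stands in for the post's Boruta selector: any estimator
# following the SelectorMixin interface slots into the same "select" step.
pipe = Pipeline([
    ("select", SelectFromModel(RandomForestClassifier(n_estimators=100, random_state=0))),
    ("model", LogisticRegression(max_iter=1000)),
])

# Selector hyperparameters are tuned with the usual step__param syntax,
# including parameters of the estimator nested inside the selector.
param_grid = {
    "select__threshold": ["mean", "median"],
    "select__estimator__max_depth": [3, None],
}

X, y = make_classification(n_samples=500, n_features=20, n_informative=5, random_state=0)
search = GridSearchCV(pipe, param_grid, cv=3, scoring="roc_auc")
search.fit(X, y)
print(search.best_params_)
```

<p><div align="justify">Swapping <code>SelectFromModel</code> for another selector only changes the <code>select</code> step and the parameter names after <code>select__</code>.</div></p>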

<h1 id="conclusão">Conclusion</h1>

<p><div align="justify">Feature selection is a necessary topic when we want to ensure a robust model. In this post we covered one of the most useful techniques for this problem and, by understanding its ideas, discussed how to adapt it to a variety of specific cases. Even if you cannot use Boruta on your particular problem, the ideas presented here let you select features with a better understanding of the weaknesses and strengths of approaches common in industry.</div></p>

<h1 id="referências"><a name="bibliography">References</a></h1>

<p><div align="justify">[1] <a href="https://cs.nyu.edu/~mohri/mlbook/">Foundations of Machine Learning. Mehryar Mohri, Afshin Rostamizadeh, and Ameet Talwalkar. MIT Press, Second Edition, 2018</a>.</div></p>

<p><div align="justify">[2] <a href="https://www.jstatsoft.org/article/view/v036i11">Feature Selection with the Boruta Package. Miron B. Kursa, Witold R. Rudnicki. Journal of Statistical Software</a>.</div></p>

<p><div align="justify">[3] <a href="https://youtu.be/_L39rN6gz7Y">Decision and Classification Trees, Clearly Explained!!!. Josh Starmer. StatQuest with Josh Starmer</a>.</div></p>

<p><div align="justify">[4] <a href="https://towardsdatascience.com/boruta-explained-the-way-i-wish-someone-explained-it-to-me-4489d70e154a">Boruta Explained Exactly How You Wished Someone Explained to You. Samuele Mazzanti. Towards Data Science</a>.</div></p>

<p><div align="justify">[5] <a href="https://github.com/scikit-learn-contrib/boruta_py">boruta_py README.md documentation. Daniel Homola</a>.</div></p>

<p><div align="justify">For more practical usage tips (and with a much better argument from authority than mine), the author of Boruta wrote the guide <a href="https://cran.r-project.org/web/packages/Boruta/vignettes/inahurry.pdf">Boruta for those in a hurry</a>, which, although written for R, has interesting practical tips from someone who knows the implementation in great depth.</div></p>

<hr />

<p><div align="justify">All files and the environment needed to reproduce the experiments can be found in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/boruta">repository for this post</a>.</div></p>

<p><div align="justify">This post was originally published on the <a href="https://medium.com/datalab-log">Experian DataLab Medium</a>! Visit the <a href="https://medium.com/datalab-log/sele%C3%A7%C3%A3o-de-vari%C3%A1veis-uma-utiliza%C3%A7%C3%A3o-cr%C3%ADtica-do-boruta-f3e974238f56">post</a> and leave a clap if you think it is worth it! :D</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇧🇷&quot;, &quot;feature selection&quot;]" /><summary type="html"><![CDATA[A step-by-step guide to using Boruta, discussing the motivation behind its construction and debating useful variations of the algorithm.]]></summary></entry><entry><title type="html">Covariate Shift: Binary Classifier</title><link href="https://vitaliset.github.io/covariate-shift-2-classificador-binario/" rel="alternate" type="text/html" title="Covariate Shift: Binary Classifier" /><published>2020-08-30T00:00:00+00:00</published><updated>2020-08-30T00:00:00+00:00</updated><id>https://vitaliset.github.io/covariate-shift-2-classificador-binario</id><content type="html" xml:base="https://vitaliset.github.io/covariate-shift-2-classificador-binario/"><![CDATA[<p><div align="justify">This post is part of a series of posts discussing the <i>covariate shift</i> problem. I assume you already know the motivation behind the problem and what we want to identify and correct. If you have not read the <a href="https://vitaliset.github.io/covariate-shift-0-introduction/">first post</a> of this series yet, I suggest reading it.</div></p>

<p><div align="justify">Now we focus on identifying <i>covariate shift</i> in the joint distribution. The problem can then be stated as: given random vectors $X$ and $Z$ and two independently drawn sets of observations $\{x_1, x_2, \cdots, x_N \} $ and $\{z_1, z_2, \cdots, z_M \} $, we want to determine, using only the collected samples, whether the joint distribution is the same, that is, whether $X\sim Z$. In the <i>dataset shift</i> context we are particularly interested in, the random vector $X$ represents the distribution of the covariates in the training set, and the random vector $Z$ represents the distribution of the explanatory variables in production data.</div></p>

<p><div align="justify">Previously, in the <a href="https://vitaliset.github.io/covariate-shift-1-qqplot/">second post</a> of the series, we discussed a technique for detecting changes in the marginal distributions of the random vectors: the QQ-plot. We also suggested a numerical variation of this visual technique.</div></p>

<p><div align="justify">Now we will use supervised machine learning to identify problems in supervised machine learning.</div></p>

<h1 id="entendendo-o-problema-de-classificação">Understanding the classification problem</h1>

<p><div align="justify">A binary classification problem arises naturally in this scenario. If we have two samples from possibly different distributions, we can train a model that tries to tell whether each data point comes from distribution $X$ or from distribution $Z$.</div></p>

<p><div align="justify">If the binary classifier can identify differences, then the distribution has changed. If it cannot, keeping a low accuracy (close to 50% when both samples have the same size), then we trust that the distribution stayed similar.</div></p>

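<p><div align="justify">Let us illustrate this technique on the data that caused the initial discomfort presented at the end of the <a href="https://vitaliset.github.io/covariate-shift-1-qqplot/">previous post</a>. Here it becomes clear that analyzing only the marginal distributions is not always enough: below, $X$ and $Z$ have exactly the same standard normal marginals and differ only in the correlation between their coordinates.</div></p>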

<p><div align="justify">Explicitly, we have random vectors $X= (X_1,X_2)$ and $Z=(Z_1, Z_2)$ such that</div></p>

\[\begin{equation*}
\begin{pmatrix}X_{1}\\
X_{2}
\end{pmatrix} \sim  \mathcal{N}
\begin{pmatrix}
\begin{bmatrix}
0\\
0
\end{bmatrix} ,
\begin{bmatrix}
1 &amp; 0.75 \\
0.75 &amp; 1 
\end{bmatrix}
\end{pmatrix} \textrm{, and }\begin{pmatrix}Z_{1}\\
Z_{2}
\end{pmatrix} \sim  \mathcal{N}
\begin{pmatrix}
\begin{bmatrix}
0\\
0
\end{bmatrix} ,
\begin{bmatrix}
1 &amp; -0.75 \\
-0.75 &amp; 1 
\end{bmatrix}
\end{pmatrix} .
\end{equation*}\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>

<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">t</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
    <span class="c1"># n draws from a bivariate normal with correlation t*0.75; .T splits the two coordinates</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">multivariate_normal</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">cov</span><span class="o">=</span><span class="p">[[</span><span class="mi">1</span><span class="p">,</span> <span class="n">t</span><span class="o">*</span><span class="mf">0.75</span><span class="p">],</span> <span class="p">[</span><span class="n">t</span><span class="o">*</span><span class="mf">0.75</span><span class="p">,</span> <span class="mi">1</span><span class="p">]],</span> <span class="n">size</span><span class="o">=</span><span class="n">n</span><span class="p">).</span><span class="n">T</span>

<span class="n">X1</span><span class="p">,</span> <span class="n">X2</span> <span class="o">=</span> <span class="n">sample</span><span class="p">(</span><span class="mi">1000</span><span class="p">)</span>      <span class="c1"># X: positive correlation (s=0)</span>
<span class="n">Z1</span><span class="p">,</span> <span class="n">Z2</span> <span class="o">=</span> <span class="n">sample</span><span class="p">(</span><span class="mi">1000</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># Z: negative correlation (s=1)</span>
</code></pre></div></div>

<center><img src="/assets/img/covariate_2_classificador_binario/imagem1.jpg" /></center>
<center><b>Figure 1</b>: samples from distributions $X$ ($s=0$) and $Z$ ($s=1$), with opposite correlations between the coordinates.</center>

<p><div align="justify">The idea is simple: we organize our data by creating a new column that tells whether each data point comes from distribution $X$ ($s=0$) or from distribution $Z$ ($s=1$).</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>

<span class="c1"># stack both samples and label their origin: s=0 for X, s=1 for Z</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">({</span><span class="s">'variavel_1'</span><span class="p">:</span><span class="n">np</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">X1</span><span class="p">,</span><span class="n">Z1</span><span class="p">]),</span> <span class="s">'variavel_2'</span><span class="p">:</span><span class="n">np</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">X2</span><span class="p">,</span><span class="n">Z2</span><span class="p">]),</span> <span class="s">'s'</span><span class="p">:[</span><span class="mi">0</span><span class="p">]</span><span class="o">*</span><span class="n">X1</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">+</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">*</span><span class="n">Z1</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]})</span>

<span class="n">X_miss</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">df</span><span class="p">.</span><span class="n">drop</span><span class="p">([</span><span class="s">'s'</span><span class="p">],</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">S_miss</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="s">'s'</span><span class="p">])</span>
</code></pre></div></div>

<table>
  <thead>
    <tr>
      <th style="text-align: center">variable 1</th>
      <th style="text-align: center">variable 2</th>
      <th style="text-align: center">y</th>
      <th style="text-align: center">s</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td style="text-align: center">0.178105</td>
      <td style="text-align: center">0.651739</td>
      <td style="text-align: center">$y_1$</td>
      <td style="text-align: center">0</td>
    </tr>
    <tr>
      <td style="text-align: center">0.464192</td>
      <td style="text-align: center">-0.461877</td>
      <td style="text-align: center">$y_2$</td>
      <td style="text-align: center">0</td>
    </tr>
    <tr>
      <td style="text-align: center">1.0948</td>
      <td style="text-align: center">0.823703</td>
      <td style="text-align: center">$y_3$</td>
      <td style="text-align: center">0</td>
    </tr>
    <tr>
      <td style="text-align: center">…</td>
      <td style="text-align: center">…</td>
      <td style="text-align: center">…</td>
      <td style="text-align: center">…</td>
    </tr>
    <tr>
      <td style="text-align: center">0.393783</td>
      <td style="text-align: center">-0.681826</td>
      <td style="text-align: center">?</td>
      <td style="text-align: center">1</td>
    </tr>
    <tr>
      <td style="text-align: center">0.623834</td>
      <td style="text-align: center">-0.344885</td>
      <td style="text-align: center">?</td>
      <td style="text-align: center">1</td>
    </tr>
    <tr>
      <td style="text-align: center">-0.800357</td>
      <td style="text-align: center">0.444416</td>
      <td style="text-align: center">?</td>
      <td style="text-align: center">1</td>
    </tr>
  </tbody>
</table>

<p><div align="justify">Here, to connect with the real scenario where this model is applied, I included a column for the target $y$, the target variable of our original problem. We never use it to identify <i>covariate shift</i>, which is expected, since we do not have the targets for the new data arriving in production.</div></p>

<p><div align="justify">With this structure in place, the idea is simple. We build a classifier that uses variables 1 and 2 to predict $s$. If it performs poorly on a test set (no better than a random guess), then the data from $X$ and $Z$ are indistinguishable to this classifier, and we have no evidence against $X\sim Z$. If, on the other hand, our classifier achieves good metrics, then the distributions differ.</div></p>
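<p><div align="justify">The whole procedure fits in a small function. A minimal sketch, assuming a <code>RandomForestClassifier</code> instead of the SVM used below (any classifier with <code>predict_proba</code> would do) and cross-validated ROC AUC as the score; <code>shift_auc</code> is a hypothetical helper, and it compares a pair of samples without shift against the pair with opposite correlations defined above:</div></p>

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

rng = np.random.default_rng(0)


def sample(n, t=1):
    """n draws (rows) from the bivariate normal above, correlation t * 0.75."""
    cov = [[1, t * 0.75], [t * 0.75, 1]]
    return rng.multivariate_normal(mean=[0, 0], cov=cov, size=n)


def shift_auc(X, Z, random_state=0):
    """Cross-validated ROC AUC of a classifier separating X (s=0) from Z (s=1).

    Values near 0.5 mean the classifier cannot tell the samples apart;
    values clearly above 0.5 point to a distribution shift.
    """
    data = np.vstack([X, Z])
    s = np.concatenate([np.zeros(len(X), dtype=int), np.ones(len(Z), dtype=int)])
    clf = RandomForestClassifier(n_estimators=200, min_samples_leaf=20,
                                 random_state=random_state)
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
    return cross_val_score(clf, data, s, cv=cv, scoring="roc_auc").mean()


auc_same = shift_auc(sample(1000), sample(1000))       # no shift
auc_shift = shift_auc(sample(1000), sample(1000, -1))  # opposite correlation
print(f"no shift: {auc_same:.3f}, shift: {auc_shift:.3f}")
```

<p><div align="justify">An AUC near 0.5 for the first pair and clearly above it for the second is what the argument above predicts; the exact numbers depend on the seed.</div></p>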

<h1 id="construindo-e-avaliando-o-classificador-binário">Building and evaluating the binary classifier</h1>

<p><div align="justify">First, we split our data into 2 sets: one for training and one for testing. Note that <code>test_size=0.8</code> keeps only 20% of the rows for training.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
<span class="n">X_miss_train</span><span class="p">,</span> <span class="n">X_miss_test</span><span class="p">,</span> <span class="n">S_miss_train</span><span class="p">,</span> <span class="n">S_miss_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">X_miss</span><span class="p">,</span> <span class="n">S_miss</span><span class="p">,</span> <span class="n">test_size</span> <span class="o">=</span> <span class="mf">0.8</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">Now we can use any binary classifier. Since I am starting to fall in love with Vapnik's work, I will use a Support Vector Machine. The default SVM hyperparameters usually do a good job, but ideally we can run a small hyperparameter search maximizing <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html"><code>roc_auc_score</code></a>.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">GridSearchCV</span>
<span class="kn">from</span> <span class="nn">sklearn.svm</span> <span class="kn">import</span> <span class="n">SVC</span>

<span class="n">param</span> <span class="o">=</span> <span class="p">{</span><span class="s">'C'</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">geomspace</span><span class="p">(</span><span class="mf">0.01</span><span class="p">,</span><span class="mi">100</span><span class="p">,</span><span class="mi">13</span><span class="p">),</span> <span class="s">'gamma'</span><span class="p">:</span> <span class="p">[</span><span class="s">'scale'</span><span class="p">]</span><span class="o">+</span><span class="nb">list</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">geomspace</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span><span class="mi">100</span><span class="p">,</span><span class="mi">10</span><span class="p">)),</span> <span class="s">'kernel'</span><span class="p">:</span> <span class="p">[</span><span class="s">'rbf'</span><span class="p">]}</span>
<span class="n">grid_search</span> <span class="o">=</span> <span class="n">GridSearchCV</span><span class="p">(</span><span class="n">SVC</span><span class="p">(</span><span class="n">probability</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span> <span class="n">param</span><span class="p">,</span> <span class="n">cv</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">scoring</span><span class="o">=</span> <span class="p">[</span><span class="s">'roc_auc'</span><span class="p">,</span><span class="s">'accuracy'</span><span class="p">],</span> <span class="n">refit</span> <span class="o">=</span> <span class="s">'roc_auc'</span><span class="p">,</span> <span class="n">return_train_score</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">grid_search</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_miss_train</span><span class="p">,</span> <span class="n">S_miss_train</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">Next, we retrain an SVM with the best hyperparameters found by the grid search on the training set and evaluate its performance on the test set.</div></p>

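<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">accuracy_score</span><span class="p">,</span> <span class="n">matthews_corrcoef</span><span class="p">,</span> <span class="n">roc_auc_score</span>

<span class="n">svm</span> <span class="o">=</span> <span class="n">SVC</span><span class="p">(</span><span class="n">probability</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="o">**</span><span class="n">grid_search</span><span class="p">.</span><span class="n">best_params_</span><span class="p">)</span>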
<span class="n">svm</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_miss_train</span><span class="p">,</span><span class="n">S_miss_train</span><span class="p">)</span>

<span class="k">print</span><span class="p">(</span><span class="s">'acuracia: '</span><span class="p">,</span><span class="n">accuracy_score</span><span class="p">(</span><span class="n">S_miss_test</span><span class="p">,</span><span class="n">svm</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_miss_test</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="s">'roc_auc: '</span><span class="p">,</span><span class="n">roc_auc_score</span><span class="p">(</span><span class="n">S_miss_test</span><span class="p">,</span><span class="n">svm</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_miss_test</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="s">'phi coeficiente: '</span><span class="p">,</span><span class="n">matthews_corrcoef</span><span class="p">(</span><span class="n">S_miss_test</span><span class="p">,</span><span class="n">svm</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_miss_test</span><span class="p">)))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>acuracia:  0.72625
roc_auc:  0.7268930344332967
phi coeficiente:  0.4827618287310226
</code></pre></div></div>

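<p><div align="justify">Our accuracy is not state of the art, but our model clearly found a pattern and can discriminate whether data points come from one distribution or the other. Note that <code>roc_auc_score</code> above receives the hard labels from <code>svm.predict</code>, so the reported value is effectively a balanced accuracy, which is why it is so close to the accuracy; passing <code>svm.predict_proba(X_miss_test)[:, 1]</code> would give the usual ranking-based AUC.</div></p>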
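<p><div align="justify">The equivalence between ROC AUC on hard labels and balanced accuracy can be checked on a small, made-up example (the arrays below are hypothetical, not the post's data). With 0/1 scores, the ROC curve has a single interior point, so its area is the mean of the true positive and true negative rates:</div></p>

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score, roc_auc_score

# Hypothetical labels, hard predictions and scores for 8 samples.
s_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
s_hard = np.array([0, 0, 1, 0, 1, 1, 0, 1])                   # like svm.predict(...)
s_score = np.array([0.1, 0.3, 0.6, 0.2, 0.9, 0.7, 0.4, 0.8])  # like predict_proba(...)[:, 1]

# Hard labels: ROC AUC collapses to balanced accuracy.
print(roc_auc_score(s_true, s_hard), balanced_accuracy_score(s_true, s_hard))

# Continuous scores: the usual ranking-based AUC.
print(roc_auc_score(s_true, s_score))
```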

<p><div align="justify">$\oint$ <i>A less classic but very useful metric is the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html">Matthews correlation coefficient</a>. Inspired by Pearson's correlation coefficient, the goal is to measure correlation between categorical attributes. This gave rise to Pearson's phi coefficient, which generalizes the correlation coefficient to our predictions and the true values of the binary target. It is a numerical way to summarize the confusion matrix. It is computed as</i></div></p>

\[\begin{equation*}
\textrm{MCC} = \frac{T_p \, T_n - F_p \, F_n}{\sqrt{(T_p+F_p)(T_p+F_n)(T_n+F_p)(T_n+F_n)}},
\end{equation*}\]

<p><div align="justify"><i>where $T_p$ is the number of true positives, $T_n$ the number of true negatives, $F_p$ the number of false positives, and $F_n$ the number of false negatives. Although it may look a bit confusing, the numerator multiplies the counts of correctly classified examples and subtracts the product of the incorrectly classified ones. The denominator normalizes the result to lie between $-1$ and $1$, where $1$ means a perfect prediction, $0$ a prediction no better than random, and $-1$ a completely inverted prediction.</i></div></p>
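<p><div align="justify">The formula translates directly into code. A minimal sketch computing MCC from the four confusion-matrix counts; the function name <code>mcc</code> is illustrative, and returning 0 when the denominator vanishes mirrors the convention of <code>sklearn.metrics.matthews_corrcoef</code>:</div></p>

```python
from math import sqrt


def mcc(tp, tn, fp, fn):
    """Matthews correlation coefficient from confusion-matrix counts.

    Returns 0.0 when any marginal is empty (zero denominator).
    """
    denominator = sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denominator == 0:
        return 0.0
    return (tp * tn - fp * fn) / denominator


print(mcc(tp=50, tn=50, fp=0, fn=0))    # 1.0: perfect prediction
print(mcc(tp=0, tn=0, fp=50, fn=50))    # -1.0: every label swapped
print(mcc(tp=25, tn=25, fp=25, fn=25))  # 0.0: no better than chance
```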

<p><div align="justify">In our illustrative two-dimensional example, we can draw the level curves of the SVM's <code>predict_proba</code> and see that it has learned the regions where each distribution is more likely.</div></p>

<center><img src="/assets/img/covariate_2_classificador_binario/imagem2.jpg" /></center>
<center><b>Figure 2</b>: level curves of the SVM's <code>predict_proba</code>, showing the most likely regions of each distribution.</center>

<p><div align="justify">$\oint$ <i>The SVM does not provide <code>predict_proba</code> out of the box; we need to pass <code>probability=True</code> when instantiating it. <code>sklearn</code> then applies <a href="https://www.cs.colorado.edu/~mozer/Teaching/syllabi/6622/papers/Platt1999.pdf">Platt's approach</a>, fitting a <a href="https://scikit-learn.org/stable/modules/svm.html#scores-probabilities">logistic regression on the SVM score</a> with an internal cross-validation. This technique can be used with any classifier to improve <a href="https://scikit-learn.org/stable/modules/calibration.html#calibration">probability calibration</a>. It is also a <a href="https://gdmarmerola.github.io/probability-calibration/">useful technique for tree ensembles</a>.</i></div></p>
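<p><div align="justify">The idea of a logistic regression on top of the SVM score can be sketched by hand. The snippet below is a simplified illustration under assumptions of my own (synthetic Gaussian data, a plain <code>LogisticRegression</code> with its default regularization, a single held-out calibration split). It is not exactly what <code>SVC(probability=True)</code> does internally, since <code>sklearn</code> fits the sigmoid with an internal cross-validation and Platt's original method also smooths the targets.</div></p>

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

rng = np.random.default_rng(0)
# Hypothetical data: two Gaussian clouds standing in for the two samples (s=0 and s=1)
features = np.vstack([rng.normal(0.0, 1.0, size=(500, 2)),
                      rng.normal(1.0, 1.0, size=(500, 2))])
labels = np.repeat([0, 1], 500)

# Keep part of the data aside so the sigmoid is not fitted on the SVM's own training points
X_fit, X_cal, s_fit, s_cal = train_test_split(
    features, labels, test_size=0.3, random_state=0, stratify=labels)

svm_raw = SVC(kernel="rbf").fit(X_fit, s_fit)   # without probability=True: scores only


def scores(model, X):
    """SVM decision scores as a single-column feature matrix."""
    return model.decision_function(X).reshape(-1, 1)


# Platt-style step: a one-feature logistic regression mapping scores to P(s=1 | x)
platt = LogisticRegression().fit(scores(svm_raw, X_cal), s_cal)
proba = platt.predict_proba(scores(svm_raw, X_cal))[:, 1]
print(proba[:5])
```

<p><div align="justify">For classifiers other than the SVM, <code>sklearn</code> packages the same idea in <code>CalibratedClassifierCV</code> with <code>method="sigmoid"</code>.</div></p>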

<h1 id="entendendo-a-mudança-na-distribuição-a-partir-do-classificador">Understanding the distribution shift through the classifier</h1>

<p><div align="justify">Now we need to assess whether the distributions differ. We can look at a histogram of <code>predict_proba</code> applied to each sample separately, as in Figure 3. Clearly, our SVM identifies regions where a point is much more likely to come from one distribution than from the other. The fact that it assigns such confident probabilities indicates that it can separate the two samples well.</div></p>

<center><img src="/assets/img/covariate_2_classificador_binario/imagem3.jpg" /></center>
<center><b>Figure 3</b>: histogram of the SVM's <code>predict_proba</code> applied separately to the samples of $X$ and $Z$.</center>

<p><div align="justify">Suppose we trust the probabilities the classifier gives us. A somewhat arbitrary metric is the percentage of the data whose <code>predict_proba</code> lies in $[0, 0.5-x) \cup (0.5+x,1]$ for some $0\leq x\lt 0.5$. For example, with $x = 0.25$ we look at the proportion of examples with <code>predict_proba</code> between $0$ and $25\%$ or between $75\%$ and $100\%$. These are the points that the classifier considers "easy to classify", since they lie in regions dominated by one of the classes.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x</span> <span class="o">=</span> <span class="mf">0.25</span>
<span class="p">((</span><span class="n">svm</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X_miss</span><span class="p">)[:,</span><span class="mi">0</span><span class="p">]</span><span class="o">&lt;</span><span class="mf">0.5</span><span class="o">-</span><span class="n">x</span><span class="p">)</span> <span class="o">|</span> <span class="p">(</span><span class="n">svm</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X_miss</span><span class="p">)[:,</span><span class="mi">0</span><span class="p">]</span><span class="o">&gt;</span><span class="mf">0.5</span><span class="o">+</span><span class="n">x</span><span class="p">)).</span><span class="nb">sum</span><span class="p">()</span><span class="o">/</span><span class="n">X_miss</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.4605
</code></pre></div></div>

<p><div align="justify">Almost half of the data lies in the "easy" regions according to this probability analysis. This is not perfect, because of outliers, but it is a clear sign that some regions of the feature space are favored by one distribution and other regions by the other. Once $x$ is fixed, we can choose a value $\varepsilon\in(0,1]$ such that, if the proportion of data in the "easy" regions is greater than $\varepsilon$, we raise an alert that the distribution has shifted.</div></p>
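<p><div align="justify">The "easy region" proportion and the $\varepsilon$ alert rule can be wrapped in two small functions. This is a sketch under the assumption that <code>proba</code> holds one column of <code>predict_proba</code>; since the two columns add up to one, the region $[0, 0.5-x) \cup (0.5+x, 1]$ gives the same proportion whichever column is used. The function names and the value of <code>epsilon</code> below are hypothetical.</div></p>

```python
import numpy as np


def easy_region_fraction(proba, x):
    """Share of predicted probabilities in [0, 0.5 - x) or (0.5 + x, 1]."""
    if not 0 <= x < 0.5:
        raise ValueError("x must lie in [0, 0.5)")
    proba = np.asarray(proba, dtype=float)
    return float(np.mean((proba < 0.5 - x) | (proba > 0.5 + x)))


def shift_alert(proba, x, epsilon):
    """Flag a possible covariate shift when the easy-region share exceeds epsilon."""
    return easy_region_fraction(proba, x) > epsilon


# Toy probabilities: two confident values and two close to 0.5
proba = np.array([0.05, 0.48, 0.52, 0.9])
print(easy_region_fraction(proba, x=0.25))        # 0.5
print(shift_alert(proba, x=0.25, epsilon=0.3))     # True
```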

<p><div align="justify">We can also try to set thresholds on the accuracy or on the phi coefficient to decide whether there is a distribution shift. Here too the choice is not obvious: we may end up monitoring too strictly or too leniently.</div></p>

<p><div align="justify">As discussed in the previous post, such universal thresholds do not exist. What matters is to look in your historical data for cases of <i>covariate shift</i> that you know happened and check whether some $\varepsilon$ would have caught them.</div></p>

<h1 id="caso-sem-mudança">Case without shift</h1>

<p><div align="justify">It is worth studying how this methodology behaves when the distribution does not change. For example, suppose both samples are generated by the same multivariate normal distribution given by</div></p>

\[\begin{equation*}
\begin{pmatrix}X_{1}\\
X_{2}
\end{pmatrix},
\begin{pmatrix}Z_{1}\\
Z_{2}
\end{pmatrix} \sim  \mathcal{N}
\begin{pmatrix}
\begin{bmatrix}
0\\
0
\end{bmatrix} ,
\begin{bmatrix}
1 &amp; 0.75 \\
0.75 &amp; 1 
\end{bmatrix}
\end{pmatrix}.
\end{equation*}\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X1</span><span class="p">,</span> <span class="n">X2</span> <span class="o">=</span> <span class="n">sample</span><span class="p">(</span><span class="mi">1000</span><span class="p">)</span>
<span class="n">Z1</span><span class="p">,</span> <span class="n">Z2</span> <span class="o">=</span> <span class="n">sample</span><span class="p">(</span><span class="mi">1000</span><span class="p">)</span>
</code></pre></div></div>
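<p><div align="justify">The <code>sample</code> function is defined earlier in this series and is not reproduced here. A hypothetical stand-in, assuming it draws <code>n</code> points from the bivariate normal above and returns the two coordinates separately, could look like this (the seeds are arbitrary):</div></p>

```python
import numpy as np

MEAN = np.array([0.0, 0.0])
COV = np.array([[1.0, 0.75],
                [0.75, 1.0]])


def sample(n, rng=None):
    """Hypothetical stand-in: n draws from N(MEAN, COV), returned as two coordinate arrays."""
    rng = np.random.default_rng() if rng is None else rng
    points = rng.multivariate_normal(MEAN, COV, size=n)
    return points[:, 0], points[:, 1]


X1, X2 = sample(1000, np.random.default_rng(1))
Z1, Z2 = sample(1000, np.random.default_rng(2))
print(np.corrcoef(X1, X2)[0, 1])   # close to the 0.75 in the covariance matrix
```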

<p><div align="justify">Repeating exactly the same procedure as before, we now get much messier level curves, as we see in Figure 4. The classifier tries to adapt a little to the particularities of the samples, but it does not dare to assign high probabilities to any region, precisely because, in this case, no region is favored by either distribution.</div></p>

<center><img src="/assets/img/covariate_2_classificador_binario/imagem4.jpg" /></center>
<center><b>Figure 4</b>: level curves of <code>predict_proba</code> when both samples come from the same distribution.</center>

<p><div align="justify">This becomes even clearer when we look at the classification metrics. As expected, the two distributions are indistinguishable: accuracy and ROC AUC are close to $0.5$ and the phi coefficient is close to $0$.</div></p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>acuracia:  0.511875
roc_auc:  0.5115329746824565
phi coeficiente:  0.023459774068708163
</code></pre></div></div>
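<p><div align="justify">The whole procedure (label the first sample with $s=0$ and the second with $s=1$, train a classifier and evaluate it on held-out data) can be sketched as below. The function name, the 25% test split and the use of <code>predict_proba</code> for the ROC AUC are assumptions of this sketch and may differ from the exact setup that produced the numbers above.</div></p>

```python
import numpy as np
from sklearn.metrics import accuracy_score, matthews_corrcoef, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC


def classifier_two_sample_check(train_features, prod_features, random_state=0):
    """Label training rows s=0 and production rows s=1, fit an SVM, report test metrics."""
    features = np.vstack([train_features, prod_features])
    s = np.concatenate([np.zeros(len(train_features), dtype=int),
                        np.ones(len(prod_features), dtype=int)])
    X_tr, X_te, s_tr, s_te = train_test_split(
        features, s, test_size=0.25, random_state=random_state, stratify=s)
    model = SVC(probability=True, random_state=random_state).fit(X_tr, s_tr)
    proba = model.predict_proba(X_te)[:, 1]   # column 1 = P(s=1 | x)
    pred = model.predict(X_te)
    return {"accuracy": accuracy_score(s_te, pred),
            "roc_auc": roc_auc_score(s_te, proba),
            "mcc": matthews_corrcoef(s_te, pred)}


rng = np.random.default_rng(0)
cov = np.array([[1.0, 0.75], [0.75, 1.0]])
same_a = rng.multivariate_normal([0, 0], cov, size=1000)
same_b = rng.multivariate_normal([0, 0], cov, size=1000)
print(classifier_two_sample_check(same_a, same_b))
```

<p><div align="justify">When the two samples come from the same distribution, the metrics should hover around $0.5$, $0.5$ and $0$; shifting the mean of one of the samples should push them up.</div></p>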

<p><div align="justify">The distribution of <code>predict_proba</code> also matches what we expected. The model is now much more conservative, placing the probabilities close to $0.5$, as we see in Figure 5.</div></p>

<center><img src="/assets/img/covariate_2_classificador_binario/imagem5.jpg" /></center>
<center><b>Figure 5</b>: histogram of <code>predict_proba</code> in the case without distribution shift, concentrated around $0.5$.</center>

<p><div align="justify">In this case, the values of <code>predict_proba</code> are concentrated between $0.4$ and $0.6$, as expected. The model is conservative and finds no easy-to-classify regions: with $x = 0.1$, the proportion of data in the "easy" regions is zero.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x</span> <span class="o">=</span> <span class="mf">0.1</span>
<span class="p">((</span><span class="n">svm</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X_miss</span><span class="p">)[:,</span><span class="mi">0</span><span class="p">]</span><span class="o">&lt;</span><span class="mf">0.5</span><span class="o">-</span><span class="n">x</span><span class="p">)</span> <span class="o">|</span> <span class="p">(</span><span class="n">svm</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X_miss</span><span class="p">)[:,</span><span class="mi">0</span><span class="p">]</span><span class="o">&gt;</span><span class="mf">0.5</span><span class="o">+</span><span class="n">x</span><span class="p">)).</span><span class="nb">sum</span><span class="p">()</span><span class="o">/</span><span class="n">X_miss</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.0
</code></pre></div></div>

<h1 id="pontos-de-atenção-e-considerações-finais">Caveats and final remarks</h1>

<p><div align="justify">As with most monitoring techniques, it is not always possible to say categorically whether there is <i>covariate shift</i>. Setting alert thresholds is a fuzzy business. The idea is always to combine several ways of evaluating the data, producing reports that must be read critically.</div></p>

<p><div align="justify">In many cases, this whole analysis, with hyperparameter optimization and expensive models such as the SVM, may be infeasible. We do not need a state-of-the-art binary classifier: it only needs to be good enough to learn the regions of each sample (if they exist) and output reasonable probabilities. So feel free to choose whichever classifier you like, as long as you are careful when analyzing its <code>predict_proba</code>. As I mentioned earlier, the default parameters of SVMs are usually reasonable, and you can always take a few subsamples of the data to run these analyses.</div></p>

<p><div align="justify">It is also reasonable to worry about the balance between the size of the training data ($s=0$) and of the production data ($s=1$), so that accuracy and other simple metrics remain meaningful. Again, since this classifier does not need to be perfect, an <i>undersample</i> of the majority class seems enough to me.</div></p>
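<p><div align="justify">A minimal sketch of the undersampling step, assuming both samples are NumPy arrays with the same columns: rows of the larger sample are dropped at random until both groups have the same size. The function name and the toy sizes are hypothetical.</div></p>

```python
import numpy as np


def undersample_majority(train_features, prod_features, rng=None):
    """Randomly drop rows from the larger sample so both groups end up with the same size."""
    rng = np.random.default_rng() if rng is None else rng
    n = min(len(train_features), len(prod_features))
    # Sampling without replacement; for the smaller group this is just a shuffle
    keep_train = rng.choice(len(train_features), size=n, replace=False)
    keep_prod = rng.choice(len(prod_features), size=n, replace=False)
    return train_features[keep_train], prod_features[keep_prod]


rng = np.random.default_rng(0)
train_features = rng.normal(size=(10_000, 3))   # s = 0, hypothetical training data
prod_features = rng.normal(size=(1_500, 3))     # s = 1, hypothetical production data
train_bal, prod_bal = undersample_majority(train_features, prod_features, rng)
print(train_bal.shape, prod_bal.shape)           # (1500, 3) (1500, 3)
```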

<p><div align="justify">Incorporated into robust production pipelines, this technique can be a smart way of detecting changes between training and production covariates. In the next post, we will use Vapnik's empirical risk minimization principle to discuss why <i>covariate shift</i> becomes a problem. That discussion will point to an elegant way of mitigating the problems caused by <i>covariate shift</i> when retraining on data closer to production data is not possible.</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇧🇷&quot;, &quot;dataset shift&quot;]" /><summary type="html"><![CDATA[A technique to identify distribution shifts from the metrics of a binary classifier.]]></summary></entry><entry><title type="html">Covariate Shift: QQ-plot</title><link href="https://vitaliset.github.io/covariate-shift-1-qqplot/" rel="alternate" type="text/html" title="Covariate Shift: QQ-plot" /><published>2020-08-16T00:00:00+00:00</published><updated>2020-08-16T00:00:00+00:00</updated><id>https://vitaliset.github.io/covariate-shift-1-qqplot</id><content type="html" xml:base="https://vitaliset.github.io/covariate-shift-1-qqplot/"><![CDATA[<p><div align="justify">This post is part of a series that discusses the problem of <i>Covariate Shift</i>. I assume you already know the motivation of the problem and what we want to identify and correct. If you have not yet read the <a href="https://vitaliset.github.io/covariate-shift-0-introduction/">first post</a> of this series, I suggest reading it first.</div></p>

<p><div align="justify">Recalling the reformulated statement of the problem, we have random variables (or vectors) $X$ and $Z$ and two sets of independently sampled observations, $\{x_1, x_2, \cdots, x_N \} $ and $\{z_1, z_2, \cdots, z_M \} $. We want to know whether the two variables have the same distribution, that is, whether $X\sim Z$, by studying only the collected samples. In the <i>dataset shift</i> context we are particularly interested in, the random vector $X$ describes the distribution of the covariates in the training set and the random vector $Z$ describes the distribution of the explanatory variables in production.</div></p>

<p><div align="justify">The first technique we will discuss uses the QQ-plot (quantile-quantile plot). By checking whether the $\alpha$-quantiles of the two samples are similar, we can discuss whether it is reasonable to assume $X\sim Z$.</div></p>

<h1 id="alpha-quantis-de-uma-variável-aleatória">$\alpha$-quantiles of a random variable</h1>

<p><div align="justify">There are a few different ways of computing $\alpha$-quantiles. They are more or less equivalent for the analyses we are interested in, so we will not dwell on small variations. We start with a very classic $\alpha$-quantile that you already know: the median.</div></p>

<p><div align="justify">The median of a dataset is the real value that splits the data into two subsets of equal size: the values greater than the median and the values less than or equal to the median. For example, for the set $S =\{ 1, 2, 4, 6, 6, 9\}$, the median can be $4$, since $|\{x \in S : x\leq 4 \}|$ $ = 3 =$ $ |\{x \in S : x\gt 4 \}|$.</div></p>

<p><div align="justify">The concept of median extends to random variables. In this case, we look for a real value $p$ such that the probability of the random variable being less than or equal to $p$ is $0.5$. This means that $p$ splits the real line into two regions, $\{ x\in\mathbb{R}:x\leq p \}$ and $\{ x\in\mathbb{R}:x\gt p \}$, with the same probability, that is, $\mathbb{P}(X\leq p)$ $=0.5=$ $\mathbb{P}(X\gt p)$.</div></p>

<p><div align="justify">Given $\alpha\in(0,1)$, the $\alpha$-quantile of a random variable $X$ generalizes what we did with the median. We want to split the line into two regions, one with probability $\alpha$ and the other with probability $1-\alpha$. For the median we had $\alpha=0.5$; here the construction is analogous but more general. The idea is that $q_X(\alpha)$, the $\alpha$-quantile of $X$, satisfies the equation</div></p>

\[\mathbb{P}\left( X\leq q_X(\alpha) \right) = \alpha.\]

<p><div align="justify">Recall that $F_X(t) = \mathbb{P}(X\leq t)$ is the cumulative distribution function of a random variable $X$. The function $q_X:(0,1)\to\mathbb{R}$, called the quantile function, would be the inverse of $F_X$, that is, $F_X(q_X(\alpha))=\alpha$. The median of a random variable $X$ is formally defined as $q_X(0.5)$.</div></p>

<p><div align="justify">However, there are problematic random variables for which this equation has no solution for some values of $\alpha\in(0,1)$. For example, if $X\sim\textrm{Ber}(0.4)$, there is no $p\in\mathbb{R}$ such that $F_X(p ) = 0.5$, since</div></p>

\[F_X(t) = \begin{cases} 0\textrm{, if }t\lt0, \\
0.6\textrm{, if }0\leq t\lt 1,\\
1\textrm{, if }t\geq1.\end{cases}\]

<p><div align="justify">Therefore, with this definition of the quantile function we cannot define $q_X(0.5)$, the median of the Bernoulli variable with parameter $0.4$.</div></p>

<p><div align="justify">Note also that, in the first example, the median of the set $S$ is not uniquely determined. We could have taken $5$ as the median, since this value also splits our data into sets of the same size (in fact, any value in $[4, 6)$ does).</div></p>

<p><div align="justify">Since we want a well-defined function, one solution to these problems is to define the <b>quantile function</b> as</div></p>

\[\begin{equation*}
q_X(\alpha) = \min \{t \in \mathbb{R} : \mathbb{P}(X\leq t) = F_X(t) \geq \alpha \}.
\end{equation*}\]

<p><div align="justify">In this case, $q_X(\alpha)$ is the smallest real value at which the cumulative probability is at least $\alpha$. For $X\sim\textrm{Ber}(0.4)$, we now have $q_X(0.5) = 0$, since $0$ is the smallest real value that makes $F_X$ greater than or equal to $0.5$. And the median of the set $S$ becomes uniquely defined, since $4$ is the smallest value that splits it into two sets of equal size.</div></p>

<p><div align="justify">For random variables $X$ whose $F_X$ is continuous, the value $q_X(\alpha)$ defined this way satisfies $F_X(q_X(\alpha))=\alpha$, so it agrees with the first attempt at a definition (and, when $F_X$ is also strictly increasing, it is exactly the inverse of $F_X$). These are the cases we will be most interested in when analyzing the QQ-plot.</div></p>

<p><div align="justify">$\oint $ <i>The generalized inverse we built is particularly useful for functions that are monotonic but discontinuous and not necessarily injective, as is the case of the cumulative distribution functions of discrete random variables. The only change needed in more general settings is to use $\inf$ instead of $\min$ (for cumulative distribution functions, right-continuity makes the two equivalent).</i></div></p>

<h2 id="cálculo-da-função-quantil-de-uma-variável-aleatória-contínua">Computing the quantile function of a continuous random variable</h2>

<p><div align="justify">When $X$ is a continuous random variable with probability density $f_X$, we have an explicit formula for $F_X$:</div></p>

\[\begin{equation*}
    F_X(t) = \int_{-\infty}^t f_X(s) \, ds.
\end{equation*}\]

<p><div align="justify">Let us compute $q_X$ explicitly for an exponentially distributed random variable $X\sim \textrm{Exp}(\lambda)$. To compute $F_X$, we use the probability density $f_X$ given by</div></p>

\[\begin{equation*}
    f_X(s) = \begin{cases}
\lambda e^{-\lambda s}\textrm{, if } s\geq 0\textrm{,}\\
0 \textrm{, otherwise.}
\end{cases}
\end{equation*}\]

<p><div align="justify">We can then compute $F_X$ as</div></p>

\[\begin{equation*}
    F_X(t) = \int_{-\infty}^{t} f_X(s) ds = \int_0^t \lambda e^{-\lambda s} ds = -\,e^{-\lambda s}\, \bigg\rvert_{0}^{t} = 1 - e^{-\lambda t},
\end{equation*}\]

<p><div align="justify">for $t\geq 0$, and $F_X(t)=0$ for $t&lt;0$.</div></p>

<p><div align="justify">In this case we can find an explicit expression for $q_X(\alpha)$ by solving the equation</div></p>

\[\begin{equation*}
    \alpha = F_X(q_X(\alpha)) = 1 - e^{-\lambda q_X(\alpha)}
\therefore 1- \alpha = e^{-\lambda q_X(\alpha)},
\end{equation*}\]

<p><div align="justify">which gives</div></p>

\[\begin{equation*}
    q_X(\alpha) = \frac{-\ln(1-\alpha)}{\lambda}.
\end{equation*}\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>

<span class="k">def</span> <span class="nf">dens_exp</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="n">lamb</span><span class="p">):</span>
    <span class="c1"># density of Exp(lamb): 0 for s &lt; 0 and lamb * exp(-lamb * s) for s &gt;= 0 (pass a float array)
</span>    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">piecewise</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">s</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">,</span> <span class="n">s</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="k">lambda</span> <span class="n">s</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">s</span><span class="p">:</span> <span class="n">lamb</span><span class="o">*</span><span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">lamb</span><span class="o">*</span><span class="n">s</span><span class="p">)])</span>

<span class="k">def</span> <span class="nf">quantil_exp</span><span class="p">(</span><span class="n">t</span><span class="p">,</span><span class="n">lamb</span><span class="p">):</span>
    <span class="c1"># closed-form quantile function of Exp(lamb): -ln(1 - t) / lamb
</span>    <span class="k">return</span> <span class="o">-</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mi">1</span><span class="o">-</span><span class="n">t</span><span class="p">)</span><span class="o">/</span><span class="n">lamb</span>
</code></pre></div></div>

<p><div align="justify">For example, to compute the median of $X\sim\textrm{Exp}(\lambda =1)$, we simply evaluate $q_X(0.5)=-\ln(0.5)\approx0.693$. Interpreting this result, $\mathbb{P}\left( X\leq -\ln(0.5) \right)=0.5$, so shading the area under the curve up to $-\ln(0.5)$, as in Figure 1, covers half of the area under the probability density.</div></p>

<center><img src="/assets/img/covariate_1_qqplot/imagem1.jpg" /></center>
<center><b>Figure 1</b>: Probability density of the exponential random variable with $\lambda=1$. The shading represents the area under the curve from 0 to $-\ln(0.5)$, which is half of the probability.</center>
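<p><div align="justify">A quick numerical sanity check of the closed form, assuming NumPy's random generator API: draw many exponential samples, compare their empirical quantiles with $-\ln(1-\alpha)/\lambda$, and plug the quantiles back into $F_X$. Note that NumPy parametrizes the exponential by its scale $1/\lambda$, not by the rate $\lambda$; the seed and sample size are arbitrary.</div></p>

```python
import numpy as np


def quantil_exp(t, lamb):
    # closed-form quantile derived above: q_X(alpha) = -ln(1 - alpha) / lambda
    return -np.log(1 - t) / lamb


lamb = 1.0
print(quantil_exp(0.5, lamb))   # ln 2, about 0.693

# Monte Carlo check: numpy's exponential takes the scale 1 / lambda
rng = np.random.default_rng(0)
draws = rng.exponential(scale=1 / lamb, size=200_000)
alphas = np.array([0.1, 0.5, 0.9])
print(np.quantile(draws, alphas))   # empirical quantiles
print(quantil_exp(alphas, lamb))    # closed form; should be close to the line above

# Plugging the quantiles back into F_X(t) = 1 - exp(-lambda t) recovers alpha
print(1 - np.exp(-lamb * quantil_exp(alphas, lamb)))
```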

<h2 id="cálculo-da-função-quantil-de-uma-variável-aleatória-discreta">Computing the quantile function of a discrete random variable</h2>

<p><div align="justify">Now suppose that $X\sim \textrm{Binomial}(2,0.5)$. Then $\mathbb{P}(X=0)=\mathbb{P}(X=2)= 0.25$ and $\mathbb{P}(X=1)=0.5$. The cumulative distribution function is</div></p>

\[F_X(t) = \begin{cases} 0\textrm{, if }t\lt0, \\
0.25\textrm{, if }0\leq t \lt 1,\\
0.75\textrm{, if }1\leq t \lt 2,\\
1\textrm{, if }t\geq2.\end{cases}\]

<p><div align="justify">To compute the quantile function, we need the version that says</div></p>

\[q_X(\alpha) = \min \{t \in \mathbb{R} : F_X(t) \geq \alpha \}.\]

<p><div align="justify">With it we have, for example, $q_X(0.9)=2$, since the smallest value of $F_X(t)$ greater than or equal to $0.9$ is $1$, which is first reached at $t=2$. Applying the same reasoning to every $\alpha \in (0,1)$, we obtain the quantile function</div></p>

\[q_X(\alpha) = \begin{cases} 0\textrm{, if }0\lt \alpha \leq 0.25, \\
1\textrm{, if }0.25\lt \alpha \leq 0.75, \\
2\textrm{, if }0.75\lt \alpha \lt 1.\end{cases}\]
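<p><div align="justify">For a distribution with finite support, the definition $q_X(\alpha) = \min \{t : F_X(t) \geq \alpha \}$ amounts to finding the first support point whose cumulative probability reaches $\alpha$. A minimal sketch (the function name is hypothetical) that reproduces the cases of the binomial quantile function and the Bernoulli median discussed earlier:</div></p>

```python
import numpy as np


def discrete_quantile(support, probs, alpha):
    """Generalized inverse q(alpha) = min{t : F(t) >= alpha} for a finite discrete law."""
    order = np.argsort(support)
    values = np.asarray(support, dtype=float)[order]
    cdf = np.cumsum(np.asarray(probs, dtype=float)[order])
    # side="left" returns the first index i with cdf[i] >= alpha
    idx = np.searchsorted(cdf, alpha, side="left")
    # guard against rounding in the last cumulative sum
    return values[min(idx, len(values) - 1)]


# Binomial(2, 0.5): P(0) = 0.25, P(1) = 0.5, P(2) = 0.25
support, probs = [0, 1, 2], [0.25, 0.5, 0.25]
for alpha in [0.1, 0.25, 0.5, 0.75, 0.9]:
    print(alpha, discrete_quantile(support, probs, alpha))   # 0, 0, 1, 1, 2 (as floats)

# Bernoulli(0.4): the median is q_X(0.5) = 0
print(discrete_quantile([0, 1], [0.6, 0.4], 0.5))
```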

<h1 id="qq-plot">QQ-plot</h1>

<p><div align="justify">The idea behind the <b>QQ-plot</b> (quantile-quantile plot) rests on a clever observation: if two random variables $X$ and $Y$ have similar distributions (that is, if $F_X \approx F_Y$), then their $\alpha$-quantiles are similar as well (that is, the quantile functions are close, $q_X \approx q_Y$).</div></p>

<p><div align="justify">Therefore, if $X$ and $Y$ have similar distributions, when we plot the "parametrized curve"</div></p>

\[\begin{equation*}
    \{ (q_X(\alpha), q_Y(\alpha) ) \in \mathbb{R}^2 : \alpha \in (0,1) \},
\end{equation*}\]

<p><div align="justify">we expect the curve to stay close to the identity line $y=x$. The name QQ-plot comes from plotting the quantiles of our random variables on both axes.</div></p>

<p><div align="justify">To visualize this plot, let us look at an analytic example. Let $X \sim \textrm{Exp}(\lambda=1)$ and $Y \sim \textrm{Uniform}([0,1])$. We have already computed $q_X(\alpha)=-\ln(1-\alpha)$ explicitly, and it is easy to check that $q_Y(\alpha) = \alpha$.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">dens_uni</span><span class="p">(</span><span class="n">s</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">piecewise</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">s</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">,</span> <span class="p">(</span><span class="n">s</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">s</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">),</span> <span class="n">s</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span> 
    
<span class="k">def</span> <span class="nf">quantil_uni</span><span class="p">(</span><span class="n">t</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">t</span>
</code></pre></div></div>

<p><div align="justify">As we can see in the first image of Figure 2, these distributions are close at the beginning (near the origin) and then become qualitatively very different. Plotting the curve given by</div></p>

\[\begin{equation*}
    \{ (-\ln(1-\alpha), \alpha ) \in \mathbb{R}^2 : \alpha \in (0,1) \},
\end{equation*}\]

<p><div align="justify">we obtain the QQ-plot in the second image of Figure 2.</div></p>

<center><img src="/assets/img/covariate_1_qqplot/imagem2.jpg" /></center>
<center><b>Figure 2</b>: on the left, the densities of $X\sim\textrm{Exp}(1)$ and $Y\sim\textrm{Uniform}([0,1])$; on the right, the corresponding analytic QQ-plot.</center>
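<p><div align="justify">The analytic QQ-plot of Figure 2 can be reproduced numerically by evaluating both quantile functions on a grid of $\alpha$ values. In this sketch the grid is an arbitrary choice, and the endpoints are left out because $q_X(\alpha)\to\infty$ as $\alpha\to 1$.</div></p>

```python
import numpy as np


def quantil_exp(t, lamb):
    # quantile function of Exp(lamb)
    return -np.log(1 - t) / lamb


def quantil_uni(t):
    # quantile function of Uniform([0, 1])
    return t


# Grid of alpha values in (0, 1): 0.01, 0.02, ..., 0.99
alphas = np.linspace(0.01, 0.99, 99)
qq_x = quantil_exp(alphas, lamb=1.0)   # horizontal axis: quantiles of Exp(1)
qq_y = quantil_uni(alphas)             # vertical axis: quantiles of Uniform([0, 1])

print(np.all(qq_y <= qq_x))            # True: the curve never goes above y = x
print(qq_x[[0, 49, 98]], qq_y[[0, 49, 98]])
```

<p><div align="justify">Since $-\ln(1-\alpha) = \alpha + \alpha^2/2 + \alpha^3/3 + \cdots \geq \alpha$, the curve always lies on or below the identity and only hugs it for small $\alpha$, near the origin.</div></p>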

<h2 id="alpha-quantis-para-amostras">$\alpha$-quantiles for samples</h2>

<p><div align="justify">When we do not know $F_X$, we cannot compute $q_X(\alpha)$ analytically. But if we have an independent and identically distributed sample $\left\{x_1,\ldots,x_N \right\}$ of size $N$ from $X$, we can estimate the $\alpha$-quantiles.</div></p>

<ul>
  <li>
    <p><div align="justify">First, we sort the sample $\left\{x_1,\ldots,x_N \right\}$ in increasing order, renaming the indices as $\left\{ x_{(1)},\ldots,x_{(N)} \right\}$.</div></p>
  </li>
  <li>
    <p><div align="justify">Then, given $\alpha \in (0,1)$, the estimate for the $\alpha$-quantile of the random variable that generated the sample is</div></p>

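    <p>\(\begin{equation*}
        \widehat{q}_{X}(\alpha) = x_{( \lceil N\alpha \rceil )},
    \end{equation*}\)
where $\lceil N\alpha \rceil$ is the smallest integer greater than or equal to $N\alpha$. This is the sample version of $q_X(\alpha) = \min \{t \in \mathbb{R} : F_X(t) \geq \alpha \}$: at least a fraction $\alpha$ of the sample is less than or equal to $x_{( \lceil N\alpha \rceil )}$, while every smaller value has less than a fraction $\alpha$ of the sample at or below it.</p>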
  </li>
</ul>

<p><div align="justify">The idea behind this estimate is that a fraction $\alpha$ of our sample is made up of elements less than or equal to $\widehat{q}_X(\alpha)$. In Figure 3 we can see some $\alpha$-quantiles of a data sample with $N=40$. Plotting the sorted data horizontally, we identify the $0.25$-quantile as the tenth element of our list, marked in green, since $25$ percent of our data are less than or equal to it.</div></p>

<center><img src="/assets/img/covariate_1_qqplot/imagem3.jpg" /></center>
<center><b>Figure 3</b>: A collection of data points sorted in increasing order and some illustrative $\alpha$-quantiles.</center>
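<p><div align="justify">A minimal sketch of the sample $\alpha$-quantile, written as the sample version of the definition $q_X(\alpha) = \min \{t : F_X(t) \geq \alpha \}$ with $F_X$ replaced by the empirical distribution function. Under this choice the estimate is the $\lceil N\alpha \rceil$-th order statistic, which coincides with $x_{(\lfloor N\alpha \rfloor +1)}$ whenever $N\alpha$ is not an integer. The function name is hypothetical; NumPy 1.22 or later exposes the same rule as <code>np.quantile(..., method="inverted_cdf")</code>.</div></p>

```python
import numpy as np


def sample_quantile(sample, alpha):
    """Sample version of q(alpha) = min{t : F_hat(t) >= alpha}, for alpha in (0, 1]."""
    ordered = np.sort(np.asarray(sample, dtype=float))
    n = len(ordered)
    # The k-th order statistic has at least k of the N points at or below it, so the
    # answer is x_(k) for the smallest k with k / N >= alpha (searchsorted is 0-based)
    k = np.searchsorted(np.arange(1, n + 1) / n, alpha, side="left")
    return ordered[k]


S = [1, 2, 4, 6, 6, 9]
print(sample_quantile(S, 0.5))       # 4.0, the median discussed earlier

data = np.arange(1, 41)             # hypothetical sorted sample with N = 40
print(sample_quantile(data, 0.25))   # 10.0: the 10th element, with 25% of the data <= it
```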

<p><div align="justify">As $N\to \infty$, we have $\widehat{q}_{X}(\alpha) \to q_{X}(\alpha)$ in probability, <a href="https://stats.stackexchange.com/questions/45124/central-limit-theorem-for-sample-medians">at least for continuous random variables</a> whose $F_X$ is strictly increasing around $q_X(\alpha)$. This lets us trust that, for large $N$, the estimated $\alpha$-quantile is close to the true $\alpha$-quantile, and we will use this fact to compare our samples.</div></p>

<h2 id="qq-plot-para-duas-amostras">QQ-plot for two samples</h2>

<p><div align="justify">The QQ-plot builds on exactly this fact: if the samples $\left\{x_1,\ldots,x_N \right\}$ and $\left\{y_1,\ldots,y_M \right\}$ came from similar distributions $X$ and $Y$, respectively, then the estimated quantile functions will also be similar,</div></p>

\[\begin{equation*}
    \widehat{q}_{X}(\alpha) \approx \widehat{q}_{Y}(\alpha).
\end{equation*}\]

<p><div align="justify">In this case, if we parametrize a curve by $\alpha$ and plot $\widehat{q}_{X}$ on the $x$ axis and $\widehat{q}_{Y}$ on the $y$ axis, we should get points close to the identity line $y=x$.</div></p>

<p><div align="justify">Varying the curve parameter in equal steps, we plot the points</div></p>

\[\begin{equation*}
    \left\{ (\widehat{q}_X(\alpha_i), \widehat{q}_Y(\alpha_i) ) \in \mathbb{R}^2 : \alpha_i = \frac{i}{k} \textrm{, para }i\in\{1,2,\cdots,k-1\} \right\},
\end{equation*}\]

<p><div align="justify">for a natural number $k \gt 2$. We are walking along the previous curve in steps of size $1/k$ in the parameter $\alpha$. For example, for $k=10$, we plot the $9$ points corresponding to the $\alpha_i$-quantiles for $\alpha_i$$=0.1$, $0.2$, $\cdots$, $0.8$, $0.9$. For $k=20$, we take the $19$ points given by $\alpha_i$$=0.05$, $0.1$, $\cdots$, $0.9$, $0.95$.</div></p>
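<p><div align="justify">Putting the pieces together, the $k-1$ points of the sample QQ-plot can be computed as below. This sketch uses the $\lceil N\alpha \rceil$-th order statistic as the sample quantile; the function names and the simulated $\mathcal{N}(0.5,1)$ samples with $N,M=10000$ and $k=25$ are assumptions for illustration.</div></p>

```python
import numpy as np


def sample_quantile(sample, alpha):
    """Sample version of q(alpha) = min{t : F_hat(t) >= alpha} (the ceil(N alpha)-th order statistic)."""
    ordered = np.sort(np.asarray(sample, dtype=float))
    n = len(ordered)
    k = np.searchsorted(np.arange(1, n + 1) / n, alpha, side="left")
    return ordered[k]


def qq_points(x_sample, y_sample, k):
    """Points (q_X(i/k), q_Y(i/k)) for i = 1, ..., k - 1."""
    alphas = np.arange(1, k) / k
    qx = np.array([sample_quantile(x_sample, a) for a in alphas])
    qy = np.array([sample_quantile(y_sample, a) for a in alphas])
    return qx, qy


rng = np.random.default_rng(0)
x_sample = rng.normal(0.5, 1.0, size=10_000)   # X ~ N(0.5, 1)
y_sample = rng.normal(0.5, 1.0, size=10_000)   # Y ~ N(0.5, 1)
qx, qy = qq_points(x_sample, y_sample, k=25)   # 24 points, alpha = 0.04, ..., 0.96
print(len(qx))                                 # 24
print(np.max(np.abs(qx - qy)))                 # small: the points hug y = x
```

<p><div align="justify">Passing <code>qx</code> and <code>qy</code> to a scatter plot, together with the line $y=x$, gives the QQ-plot.</div></p>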

<p><div align="justify">Figure 4 shows several QQ-plots for different choices of random variables $X$ and $Y$, sample sizes $N$ and $M$, and numbers of plotted points $k-1$.</div></p>

<ul>
  <li>
    <p><div align="justify">In the first image of Figure 4, $X,Y\sim\mathcal{N}(0.5,1)$ with $N,M=200$ and $k=10$. The points approach the identity, but with some scatter: since the samples are small, the estimates of the $\alpha$-quantiles vary considerably.</div></p>
  </li>
  <li>
    <p><div align="justify">In the second image, we have the same distributions, but now with $N,M=10000$ and $k=25$. The estimated $\alpha$-quantiles are more precise, so the points lie right on the identity line.</div></p>
  </li>
  <li>
    <p><div align="justify">In the third image, we have $X\sim\textrm{Uniform}([0,1])$ and $Y\sim\mathcal{N}(0.5,1)$ with $N=2000$, $M=1000$ and $k=25$. Here the two generating distributions have the same mean (which is why the middle points stay close to the identity), but we can still see that the distributions differ.</div></p>
  </li>
  <li>
    <p><div align="justify">In the fourth image, we have $X\sim\mathcal{N}(0,1)$ and $Y\sim\mathcal{N}(1,1)$ with $N,M=3000$ and $k=20$. Since the distributions are the same up to a shift in the mean, the points lie on the line $y=x+1$ instead of the identity.</div></p>
  </li>
  <li>
    <p><div align="justify">The fifth image is the sample version of the QQ-plot we built analytically in Figure 2, with $X\sim\textrm{Exp}(1)$ and $Y\sim\textrm{Uniform}([0,1])$. Here we use $N,M=2000$ and $k=100$.</div></p>
  </li>
  <li>
    <p><div align="justify">Finally, the last image compares the binomial distribution with the normal distribution. We take $X\sim\textrm{Binomial}(400,0.5)$ and $Y\sim\mathcal{N}(200,100)$, which has the same mean $400\cdot 0.5=200$ and variance $400\cdot 0.5\cdot 0.5=100$, with $N,M=4000$ and $k=20$.</div></p>

    <p><div align="justify">$\oint$ <i>For each $t\in\mathbb{N^*}$, defining $Z_t\sim\textrm{Binomial}(t,0.5)$, we have</i> </div></p>

\[\frac{Z_t- 0.5\, t}{0.5\, \sqrt{t}}\overset{\mathscr{D}}{\to} \mathcal{N}(0,1)\]

    <p><div align="justify"><i>by the central limit theorem, noting that $Z_t\sim\sum_{i=1}^t B_i$, where the $B_i \sim \textrm{Bernoulli}(0.5)$ are independent.</i></div></p>
  </li>
</ul>

<center><img src="/assets/img/covariate_1_qqplot/imagem4.jpg" /></center>
<center><b>Figure 4</b>: QQ-plots for different choices of distributions $X$ and $Y$, sample sizes $N$ and $M$, and number of points $k-1$, as described in the text.</center>

<p><div align="justify">The QQ-plot was originally designed as a <b>visual</b> way of telling whether two samples come from similar distributions. By itself, this kind of analysis does not give us a numerical metric we can track.</div></p>

<h2 id="sugestão-de-métrica-quantitativa">Suggested quantitative metric</h2>

<p><div align="justify">To get a numerical value that tells us whether our distributions are close, recall the idea behind the QQ-plot: we compare the points with the identity line. This suggests using a regression metric to measure how well the identity line $f(x)=x$ fits the set of QQ-plot points</div></p>

\[\begin{equation*}
    \left\{ (\widehat{q}_X(\alpha_i), \widehat{q}_Y(\alpha_i) ) \in \mathbb{R}^2 : \alpha_i = \frac{i}{k} \textrm{, for }i\in\{1,2,\cdots,k-1\} \right\}.
\end{equation*}\]

<p><div align="justify">Using the $\textrm{MSE}$ or the $\textrm{MAE}$, for example, we get the expressions:</div></p>

\[\textrm{MSE} = \frac{1}{k-1} \sum_{i=1}^{k-1} (f(\widehat{q}_X(\alpha_i)) - \widehat{q}_Y(\alpha_i))^2 = \frac{1}{k-1} \sum_{i=1}^{k-1} (\widehat{q}_X(\alpha_i) - \widehat{q}_Y(\alpha_i))^2,\]

\[\textrm{MAE} = \frac{1}{k-1}\sum_{i=1}^{k-1} \left|\widehat{q}_X(\alpha_i) - \widehat{q}_Y(\alpha_i)\right|.\]
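<p><div align="justify">A minimal sketch of both metrics, again assuming <code>np.quantile</code> as the quantile estimator; <code>qq_metrics</code> is a hypothetical helper name. It also checks that swapping the samples leaves the values unchanged.</div></p>

```python
import numpy as np

def qq_metrics(x, y, k):
    """MSE and MAE of the identity line f(x) = x on the QQ points.

    Uses np.quantile (linear interpolation) as the quantile estimator,
    an assumption that may differ from the estimator used in the figures.
    """
    alphas = np.arange(1, k) / k            # alpha_i = i/k, i = 1, ..., k-1
    diff = np.quantile(x, alphas) - np.quantile(y, alphas)
    mse = np.mean(diff ** 2)                # average of squared gaps over k-1 points
    mae = np.mean(np.abs(diff))             # average of absolute gaps over k-1 points
    return mse, mae

rng = np.random.default_rng(0)
x = rng.normal(0, 1, size=3000)
y_same = rng.normal(0, 1, size=3000)
y_shifted = rng.normal(1, 1, size=3000)

print(qq_metrics(x, y_same, k=10))      # both values close to 0
print(qq_metrics(x, y_shifted, k=10))   # MSE and MAE close to 1
print(qq_metrics(x, y_shifted, k=10) == qq_metrics(y_shifted, x, k=10))  # symmetric
```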

<p><div align="justify">$\oint$ <i>I like metrics such as $\textrm{MSE}$ and $\textrm{MAE}$ for their symmetry: swapping the samples $X$ and $Y$ does not change the result.</i> </div></p>

<p><div align="justify">Figure 5 shows some QQ-plots with their metrics. We always use $N,M=3000$. In the first image, $X, Y\sim\mathcal{N}(0,1)$, with $k=10$. In the second, $X\sim\textrm{Uniform}([0,1])$ and $Y\sim\textrm{Uniform}([-1,2])$, with $k=25$. In the third, $X\sim\textrm{Uniform}([0,1])$ and $Y\sim\mathcal{N}(0.5,1)$, with $k=30$. Finally, $X,Y\sim\mathcal{N}(300,400)$, with $k=20$.</div></p>

<center><img src="/assets/img/covariate_1_qqplot/imagem5.jpg" /></center>
<center><b>Figure 5</b>: QQ-plots and their metrics ($\textrm{MSE}$ and $\textrm{MAE}$) for the examples described in the text.</center>

<p><div align="justify">As we can see, computing the metrics this way does not solve the problem. Depending on the scale of the data, the metric can be inflated even when both samples come from the same distribution. This happens in the last QQ-plot of Figure 5.</div></p>
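<p><div align="justify">The sketch below reproduces the scale effect with synthetic data (the seed and the helper <code>qq_mse</code> are illustrative choices). Mapping the same two standard normal samples to $\mathcal{N}(300,400)$ through $x\mapsto 300+20x$ multiplies every quantile gap by $20$, so the $\textrm{MSE}$ grows by a factor of $20^2=400$, up to floating-point rounding, although the agreement between the distributions has not changed at all.</div></p>

```python
import numpy as np

def qq_mse(x, y, k):
    """MSE of the identity line on the QQ points at alpha_i = i/k."""
    alphas = np.arange(1, k) / k
    return np.mean((np.quantile(x, alphas) - np.quantile(y, alphas)) ** 2)

rng = np.random.default_rng(1)
x = rng.normal(0, 1, size=3000)
y = rng.normal(0, 1, size=3000)

# The same two samples, stretched to N(300, 400), i.e. standard deviation 20.
x_big, y_big = 300 + 20 * x, 300 + 20 * y

# Same distributions, but the second value is 400 times the first.
print(qq_mse(x, y, 20), qq_mse(x_big, y_big, 20))
```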

<p><div align="justify">One way to keep the values from being much larger than $1$ in absolute value is to apply a <code>StandardScaler</code> to the data. We compute the sample mean $\widehat{\mu_X}$ and the sample standard deviation $S_X$ of the sample $\{x_1,x_2,\cdots,x_N\}$ and transform both samples so that</div></p>

\[\begin{equation*}
    \left\{ x_i^* = \frac{x_i - \widehat{\mu_X}}{S_X}\right\} \textrm{, and also } \left\{ y_i^* = \frac{y_i - \widehat{\mu_X}}{S_X}\right\}.
\end{equation*}\]

<p><div align="justify">Note that this does not change the shape of the QQ-plot: since the same <i>scaler</i> is applied to both axes, it only rescales and translates them. The idea is that if $X\sim Y$, then the <i>scaler</i> fitted on the sample of $X$ should leave both samples with mean close to $0$ and variance close to $1$.</div></p>
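<p><div align="justify">A sketch of this standardization with scikit-learn's <code>StandardScaler</code>, fitted on the sample of $X$ only and applied to both samples. Note that <code>StandardScaler</code> divides by the population standard deviation (<code>ddof=0</code>), which may differ slightly from the $S_X$ intended above. The final check illustrates that the QQ-plot keeps its shape: the standardized quantiles are exactly the affine image $(\widehat{q}-\widehat{\mu_X})/S_X$ of the raw ones.</div></p>

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(7)
x = rng.normal(300, 20, size=3000).reshape(-1, 1)   # reference sample X
y = rng.normal(300, 20, size=3000).reshape(-1, 1)   # sample Y to compare

# Fit the scaler on X only and apply the same transform to both samples.
scaler = StandardScaler().fit(x)
x_star = scaler.transform(x).ravel()
y_star = scaler.transform(y).ravel()

alphas = np.arange(1, 20) / 20
qx, qy = np.quantile(x.ravel(), alphas), np.quantile(y.ravel(), alphas)
qx_star, qy_star = np.quantile(x_star, alphas), np.quantile(y_star, alphas)

# The scaled QQ points are an affine image of the raw ones:
# the plot keeps its shape, only the axes are shifted and rescaled.
mu, sigma = scaler.mean_[0], scaler.scale_[0]
print(np.allclose(qx_star, (qx - mu) / sigma), np.allclose(qy_star, (qy - mu) / sigma))
```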

<p><div align="justify">Figure 6 shows QQ-plots built with this methodology and their metrics. We now fix $N,M=3000$ and $k=20$. In the first image, $X\sim\textrm{Exp}(1)$ while $Y\sim \mathcal{N}(0,1)$. In the second, $X\sim \mathcal{N}(10,9)$ and $Y\sim\mathcal{N}(5,1)$. In the third, $X\sim \mathcal{N}(11,1)$ and $Y\sim \mathcal{N}(10,1)$. Finally, in the last one, $X,Y\sim \mathcal{N}(300,400)$.</div></p>

<center><img src="/assets/img/covariate_1_qqplot/imagem6.jpg" /></center>
<center><b>Figure 6</b>: QQ-plots with standardized data (StandardScaler) and their metrics, for the examples described in the text.</center>

<p><div align="justify">This makes it more likely that samples from the same distribution yield low metric values regardless of scale, as in the last image of Figure 6.</div></p>

<p><div align="justify">$\oint$ <i>One small caveat: the metric is no longer always symmetric, because the mean and variance of the sample of $Y$ are generally different from those of $X$, so swapping the samples changes the scaler.</i> </div></p>
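<p><div align="justify">The asymmetry is easy to see numerically. This sketch (synthetic data, hypothetical helper <code>scaled_qq_mse</code>) uses the second example of Figure 6, $X\sim\mathcal{N}(10,9)$ and $Y\sim\mathcal{N}(5,1)$, and computes the standardized $\textrm{MSE}$ once with the scaler fitted on $X$ and once with it fitted on $Y$. The second value comes out several times larger.</div></p>

```python
import numpy as np

def scaled_qq_mse(x, y, k=20):
    """QQ-plot MSE after standardizing both samples with X's mean and std."""
    mu, sd = x.mean(), x.std()
    alphas = np.arange(1, k) / k
    qx = np.quantile((x - mu) / sd, alphas)
    qy = np.quantile((y - mu) / sd, alphas)
    return np.mean((qx - qy) ** 2)

rng = np.random.default_rng(3)
x = rng.normal(10, 3, size=3000)   # X ~ N(10, 9)
y = rng.normal(5, 1, size=3000)    # Y ~ N(5, 1)

# Swapping the arguments changes which sample defines the scaler.
print(scaled_qq_mse(x, y), scaled_qq_mse(y, x))
```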

<p><div align="justify">For fixed $N$, $M$, and $k$, the ideal would be to define a universal $\varepsilon\in \mathbb{R}^+$ and a criterion such as: <i>if the $\textrm{MSE}$ (or $\textrm{MAE}$) is $&lt; \varepsilon$, we suspect that $X\sim Y$; otherwise, we believe that $X\nsim Y$</i>. However, this seems impossible: the value of $\varepsilon$ depends on the nature of the data and on how tolerant we are of <i>covariate shift</i>.</div></p>

<p><div align="justify">To assess whether this kind of monitoring is useful, it is worth applying it to real data from the domain you are working in, to understand how the suggested metrics ($\textrm{MAE}$ and $\textrm{MSE}$) behave both when there is no <i>dataset shift</i> and when there is.</div></p>

<p><div align="justify">If you do not have many versions of the data from different periods, or if you do not know whether there is <i>covariate shift</i>, you can split a single dataset into two disjoint sets. Check the value of the metric on these two samples, then artificially change the distribution of the second one by adding and multiplying noise to its values, and see how the metric responds.</div></p>
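<p><div align="justify">The procedure above can be sketched as follows. A synthetic log-normal column stands in for a real feature, and the perturbation sizes are arbitrary; both are assumptions to be replaced by your own data and by the kinds of shift you care about. The baseline (two halves of the same data) shows the size of the metric under no shift, against which the perturbed values can be compared when choosing $\varepsilon$.</div></p>

```python
import numpy as np

def qq_mse_scaled(ref, other, k=20):
    """QQ-plot MSE after standardizing both samples with ref's mean and std."""
    mu, sd = ref.mean(), ref.std()
    alphas = np.arange(1, k) / k
    q_ref = np.quantile((ref - mu) / sd, alphas)
    q_other = np.quantile((other - mu) / sd, alphas)
    return np.mean((q_ref - q_other) ** 2)

rng = np.random.default_rng(0)
# Synthetic stand-in for a single feature column of one dataset.
feature = rng.lognormal(mean=0.0, sigma=0.5, size=6000)

# Split the same dataset into two disjoint halves at random.
idx = rng.permutation(feature.size)
first, second = feature[idx[:3000]], feature[idx[3000:]]
baseline = qq_mse_scaled(first, second)   # no shift: sampling noise only

# Artificially shift the second half with multiplicative and additive noise.
perturbed_metrics = {}
for mult, add in [(1.0, 0.3), (1.3, 0.0), (1.3, 0.3)]:
    noisy = (second * rng.normal(mult, 0.05, size=second.size)
             + rng.normal(add, 0.05, size=second.size))
    perturbed_metrics[(mult, add)] = qq_mse_scaled(first, noisy)

print(f"baseline: {baseline:.4f}")
for params, value in perturbed_metrics.items():
    print(params, f"{value:.4f}")
```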

<h1 id="problemas-e-considerações-finais">Problems and final remarks</h1>

<p><div align="justify">The QQ-plot is a very useful visual strategy for checking for <i>covariate shift</i>, and an interesting, efficient way to build data-quality monitoring reports. It is easy to explain and implement, and it is not computationally expensive, since computing the $\alpha$-quantiles only requires sorting the data. Despite these qualities, it has some important problems.</div></p>

<p><div align="justify">The QQ-plot works well for continuous random variables. For discrete random variables, however, the quantile functions are generally pathological, with discontinuities, and the estimated quantile functions are not very reliable.</div></p>

<p><div align="justify">$\oint$ <i>Consider the case where $X,Y\sim\textrm{Ber}(0.5)$, so that $q_X(0.5)=0$. Now suppose that one of our samples has a single extra $0$ and the other has a single extra $1$. In this scenario, the estimated medians would be $0$ and $1$, respectively, and we would get a point far away from the identity line. This problem does not depend on the sample sizes and can inflate our metric. The lack of continuity is what causes these problems.</i></div></p>
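<p><div align="justify">The Bernoulli example can be checked directly with <code>np.median</code> on idealized samples of size $2n+1$ that differ by a single observation. The estimated medians are $0$ and $1$ for every $n$, so, when $\alpha=0.5$ is one of the $\alpha_i$, this QQ point sits at distance $1$ from the identity and alone contributes $1/(k-1)$ to the $\textrm{MSE}$, no matter how much data we collect.</div></p>

```python
import numpy as np

medians = []
for n in [10, 100, 10_000]:
    # Idealized Ber(0.5) samples of size 2n + 1 that differ by one observation.
    extra_zero = np.array([0] * (n + 1) + [1] * n)
    extra_one = np.array([0] * n + [1] * (n + 1))
    medians.append((float(np.median(extra_zero)), float(np.median(extra_one))))

print(medians)   # [(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)]
```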

<p><div align="justify">Moreover, even for continuous random variables, the QQ-plot falls short of giving us a numerical metric for automated monitoring. The choice of $\varepsilon$ is too arbitrary: in many cases we may raise unnecessary alerts by being too strict, or let problematic cases through by being too tolerant.</div></p>

<p><div align="justify">Finally, this type of metric evaluates each random variable separately. In many cases, <i>covariate shift</i> happens in the joint distribution of the random vector and cannot be detected by looking at the marginal distributions. An example of this problem is shown in Figure 7.</div></p>

<center><img src="/assets/img/covariate_1_qqplot/imagem7.jpg" /></center>
<center><b>Figure 7</b>: an example in which covariate shift occurs in the joint distribution without being visible in the marginal distributions.</center>
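<p><div align="justify">A synthetic sketch in the spirit of Figure 7 (the specific construction is an assumption, not the data behind the figure): in the old data the second feature closely follows the first, while in the new data it moves against it. Each marginal keeps the same distribution, so the per-feature QQ metrics stay near $0$, yet the correlation flips sign.</div></p>

```python
import numpy as np

def qq_mse(x, y, k=20):
    """MSE of the identity line on the QQ points at alpha_i = i/k."""
    alphas = np.arange(1, k) / k
    return np.mean((np.quantile(x, alphas) - np.quantile(y, alphas)) ** 2)

rng = np.random.default_rng(5)
n = 3000

# Old data: the second feature closely follows the first (positive correlation).
z_old = rng.normal(size=n)
old = np.column_stack([z_old, z_old + 0.1 * rng.normal(size=n)])

# New data: each marginal has the same distribution as before,
# but the second feature now moves against the first.
z_new = rng.normal(size=n)
new = np.column_stack([z_new, -z_new + 0.1 * rng.normal(size=n)])

marginal_mse = [qq_mse(old[:, j], new[:, j]) for j in range(2)]
corr_old = np.corrcoef(old.T)[0, 1]
corr_new = np.corrcoef(new.T)[0, 1]

print("marginal QQ MSEs:", np.round(marginal_mse, 4))   # both near 0
print("correlations:", round(corr_old, 3), round(corr_new, 3))
```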

<p><div align="justify">In the next posts of this series, we will look at another technique that can help in these cases. In general, each <i>covariate shift</i> monitoring technique has its strengths and weaknesses. Ideally, we should always have several different ways of identifying potential problems and intervening.</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇧🇷&quot;, &quot;dataset shift&quot;]" /><summary type="html"><![CDATA[A first approach to detecting distribution shifts, suggesting a numerical variant of the visual method.]]></summary></entry><entry><title type="html">Covariate Shift: Introduction</title><link href="https://vitaliset.github.io/covariate-shift-0-introduction/" rel="alternate" type="text/html" title="Covariate Shift: Introduction" /><published>2020-08-02T00:00:00+00:00</published><updated>2020-08-02T00:00:00+00:00</updated><id>https://vitaliset.github.io/covariate-shift-0-introduction</id><content type="html" xml:base="https://vitaliset.github.io/covariate-shift-0-introduction/"><![CDATA[<p><div align="justify">This text was originally written in Portuguese and later translated. The original Portuguese version can be found in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/covariate_introduction">experiments repository</a>.</div></p>

<hr />

<p><div align="justify">The primary goal of supervised learning is to identify patterns between independent variables (explanatory variables) and a dependent variable (target variable). In mathematical terms, within a regression context, we have a random vector $V = (X_1, X_2, \cdots, X_n, Y)$ and we suppose that there exists a relationship between the independent variables $X_i$ and the dependent variable $Y$, expressed as:</div></p>

\[\left(Y \,|\, X_1=x_1, X_2=x_2,\cdots, X_n=x_n\right)\sim f(x_1, x_2,\cdots, x_n) + \varepsilon,\]

<p><div align="justify">where $f:\mathbb{R}^n\to \mathbb{R}$ is any given function and $\varepsilon$ is a random variable with mean $0$, referred to as noise (which might also vary depending on the values of $X_i$). The supervised learning approach attempts to estimate the function $f$ using prior observations (a sample of the random vector $V$).</div></p>

<p><div align="justify">$\oint$ <em>We use regression as the example because it is the most straightforward case, but classification is not much more complex. In binary classification, the aim is to estimate $f:\mathbb{R}^n\to [0,1]$ such that:</em></div></p>

\[\left(Y \,|\, X_1=x_1, X_2=x_2,\cdots, X_n=x_n\right)\sim \textrm{Bernoulli}(p)\textrm{, with }p=f(x_1, x_2,\cdots, x_n).\]
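<p><div align="justify">A minimal sketch of this generative model with a single feature, assuming a logistic curve as an illustrative $f$ (the text does not fix one). Among observations with $x$ close to $1$, the frequency of $Y=1$ should approximate $f(1)$.</div></p>

```python
import numpy as np

def f(x):
    # Hypothetical f: R -> [0, 1] (a logistic curve), chosen only for illustration.
    return 1 / (1 + np.exp(-3 * x))

rng = np.random.default_rng(0)
x = rng.normal(0, 1, size=100_000)
y = rng.binomial(1, f(x))          # Y | X = x ~ Bernoulli(f(x))

# Among observations with x close to 1, the frequency of Y = 1 approximates f(1).
near_one = np.isclose(x, 1.0, atol=0.05)
freq_near_one = y[near_one].mean()
print(round(freq_near_one, 3), round(f(1.0), 3))
```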

<p><div align="justify">In general, we expect the performance estimated on the validation folds during cross-validation to carry over to new data. Machine learning in non-stationary environments, however, poses a challenge: what happens if there is a dataset shift, that is, if the distribution of the random vector $V$ is different in the new data? Can we still expect the model to deliver its validated performance?</div></p>

<p><div align="justify">In this context, two scenarios are common [<a href="#bibliography">1</a>]. The first, concept shift, happens when the function $f$ connecting the variables $X_i$ and $Y$ changes. A less noticeable but equally worrying issue arises when the relationship between the explanatory and target variables stays the same, but the distribution of the variables $X_i$ in new examples differs from their distribution in the training data. This is known as covariate shift, and in this series of posts we will learn how to identify it and discuss a possible way to handle it.</div></p>

<p><div align="justify">But first, let's create an artificial scenario that exhibits covariate shift. This will help illuminate the concepts through a practical situation and explore the problems that may emerge if this shift isn't properly identified and addressed.</div></p>

<hr />

<h2 id="example-of-dataset-shift-between-training-data-and-production-data">Example of dataset shift between training data and production data</h2>

<p><div align="justify">Let $X$ be a random variable with a normal distribution, $X\sim \mathcal{N}(0,1)$. Let $f:\mathbb{R}\to\mathbb{R}$ be the function $f(x) = \cos(2\pi x)$, and let $\varepsilon \sim \mathcal{N}(0,0.25)$ be the noise (variance $0.25$, that is, standard deviation $0.5$, which is the value passed to <code>normal</code> in <code>f_ruido</code>). We will build a dataset generated by this random experiment.</div></p>

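<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>

<span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">X</span><span class="p">):</span>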
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">X</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">f_ruido</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">random_state</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">f</span><span class="p">(</span><span class="n">X</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">random_state</span><span class="p">).</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>

<span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="n">rs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">random_state</span><span class="p">).</span><span class="n">randint</span><span class="p">(</span>
        <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="o">**</span><span class="mi">32</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">int64</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">2</span>
    <span class="p">)</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">rs</span><span class="p">[</span><span class="mi">0</span><span class="p">]).</span><span class="n">normal</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">n</span><span class="p">)</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">f_ruido</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">rs</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    <span class="k">return</span> <span class="n">X</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">Y</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<p><div align="justify">In this example, we draw $100$ observations from this experiment, with $X$ centered at $0$ as described above.</div></p>

<p><div align="justify">Despite the noise being of the same order of magnitude as $f$, the pattern of the function that drives the generation of the data can still be discerned. Our goal is to make predictions: given new observations of $X=x$, we aim to estimate the corresponding values for $(Y \, | \, X=x)$.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X_past</span><span class="p">,</span> <span class="n">Y_past</span> <span class="o">=</span> <span class="n">sample</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>

<span class="n">x_plot</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">X_past</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">X_past</span><span class="p">),</span> <span class="mi">1000</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_past</span><span class="p">,</span> <span class="n">Y_past</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Sample"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x_plot</span><span class="p">,</span> <span class="n">f</span><span class="p">(</span><span class="n">x_plot</span><span class="p">),</span> <span class="n">c</span><span class="o">=</span><span class="s">"k"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"f(x)"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"x"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"y"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/covariate_0_introduction/output_5_0.png" /></center></div></p>

<p><div align="justify">We will use a simple regression model, the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html"><code>sklearn.tree.DecisionTreeRegressor</code></a>. With <a href="https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html"><code>sklearn.model_selection.GridSearchCV</code></a>, we select the best value for the minimum number of samples per leaf (a regularization parameter that helps prevent overfitting). Cross-validation also gives us an estimate of the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html"><code>sklearn.metrics.r2_score</code></a> we can expect when applying the decision tree to unseen data.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.tree</span> <span class="kn">import</span> <span class="n">DecisionTreeRegressor</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">GridSearchCV</span>

<span class="n">dtr</span> <span class="o">=</span> <span class="n">DecisionTreeRegressor</span><span class="p">(</span><span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span>
<span class="n">param</span> <span class="o">=</span> <span class="p">{</span><span class="s">"min_samples_leaf"</span><span class="p">:</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">)}</span>
<span class="n">grid_search</span> <span class="o">=</span> <span class="n">GridSearchCV</span><span class="p">(</span>
    <span class="n">dtr</span><span class="p">,</span> <span class="n">param</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">scoring</span><span class="o">=</span><span class="s">"r2"</span><span class="p">,</span> <span class="n">return_train_score</span><span class="o">=</span><span class="bp">True</span>
<span class="p">).</span><span class="n">fit</span><span class="p">(</span><span class="n">X_past</span><span class="p">,</span> <span class="n">Y_past</span><span class="p">)</span>

<span class="n">df_cv</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">grid_search</span><span class="p">.</span><span class="n">cv_results_</span><span class="p">)</span>
    <span class="p">.</span><span class="n">sort_values</span><span class="p">(</span><span class="s">"rank_test_score"</span><span class="p">)</span>
    <span class="p">.</span><span class="nb">filter</span><span class="p">([</span><span class="s">"param_min_samples_leaf"</span><span class="p">,</span> <span class="s">"mean_test_score"</span><span class="p">,</span> <span class="s">"std_test_score"</span><span class="p">])</span>
<span class="p">)</span>
<span class="n">df_cv</span><span class="p">.</span><span class="n">head</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
</code></pre></div></div>

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>param_min_samples_leaf</th>
      <th>mean_test_score</th>
      <th>std_test_score</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>2</th>
      <td>3</td>
      <td>0.554561</td>
      <td>0.094576</td>
    </tr>
    <tr>
      <th>6</th>
      <td>7</td>
      <td>0.502175</td>
      <td>0.100091</td>
    </tr>
    <tr>
      <th>3</th>
      <td>4</td>
      <td>0.490702</td>
      <td>0.131177</td>
    </tr>
  </tbody>
</table>
</div>

<p><div align="justify">We obtain a reasonable $R^2$ (about $0.55$ with <code>min_samples_leaf=3</code>), indicating that the model captures the patterns in the data despite its simplicity and the small size of the dataset.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_past</span><span class="p">,</span> <span class="n">Y_past</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.6</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Sample"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x_plot</span><span class="p">,</span> <span class="n">f</span><span class="p">(</span><span class="n">x_plot</span><span class="p">),</span> <span class="n">c</span><span class="o">=</span><span class="s">"k"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"f(x)"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">x_plot</span><span class="p">,</span>
    <span class="n">grid_search</span><span class="p">.</span><span class="n">best_estimator_</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">x_plot</span><span class="p">),</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"r"</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"Decision tree estimator"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"x"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"y"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/covariate_0_introduction/output_9_0.png" /></center></div></p>

<p><div align="justify">Visually, the model performs well around $x=0$, where there's a high density of $x$ values. As expected, the model's performance deteriorates at the fringes where fewer training examples are present.</div></p>

<p><div align="justify">Let's now imagine a scenario where circumstances have changed: the relationship between $X$ and $Y$ remains intact, but for some reason, the distribution of the variable $X$ is no longer $X\sim \mathcal{N}(0,1)$. Instead, it's given by $X\sim \mathcal{N}(2,1)$. In other words, there's a shift in the distribution.</div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">X_new</span><span class="p">,</span> <span class="n">Y_new</span> <span class="o">=</span> <span class="n">sample</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="n">mean</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">13</span><span class="p">)</span>

<span class="n">min_X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">vstack</span><span class="p">([</span><span class="n">X_past</span><span class="p">,</span> <span class="n">X_new</span><span class="p">]))</span>
<span class="n">max_X</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">vstack</span><span class="p">([</span><span class="n">X_past</span><span class="p">,</span> <span class="n">X_new</span><span class="p">]))</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">hist</span><span class="p">(</span>
    <span class="n">X_past</span><span class="p">,</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.6</span><span class="p">,</span>
    <span class="n">bins</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">min_X</span><span class="p">,</span> <span class="n">max_X</span><span class="p">,</span> <span class="mi">16</span><span class="p">),</span>
    <span class="n">density</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"Old distribution of X"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">hist</span><span class="p">(</span>
    <span class="n">X_new</span><span class="p">,</span>
    <span class="n">alpha</span><span class="o">=</span><span class="mf">0.6</span><span class="p">,</span>
    <span class="n">bins</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">min_X</span><span class="p">,</span> <span class="n">max_X</span><span class="p">,</span> <span class="mi">16</span><span class="p">),</span>
    <span class="n">density</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"New distribution of X"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"x"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Density of X"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/covariate_0_introduction/output_11_0.png" /></center></div></p>

<p><div align="justify">It is not reasonable to expect that our model will maintain the same performance as before. The estimation of the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html"><code>sklearn.metrics.r2_score</code></a> was made based on the original distribution of $X$, which has now shifted.</div></p>

<p><div align="justify">$\oint$ <em>We will delve into this in more depth in a future post in this series, but essentially, the previous model was trained to identify a function $h$ that minimizes the expected squared error in the distribution $(X_{\textrm{old}}, Y)$. Mathematically, this can be represented as:</em></div></p>

\[h^* = \arg\min_{h\in\mathcal{H}}\,\mathbb{E}_{(X_{\textrm{old}}, Y)} \left(\left(h(X) - Y\right)^2\right).\]

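<p><div align="justify"><em>In practice, this minimization was only approximate: it used the training sample $\{(x_i, y_i)\}_{i=1}^{n}$ and minimized the empirical mean squared error $\frac{1}{n}\sum_{i=1}^{n}\left(h(x_i) - y_i\right)^2$. Now, however, we are dealing with data from a different distribution. Ideally, we should be minimizing:</em></div></p>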

\[\mathbb{E}_{(X_{\textrm{new}}, Y)} \left(\left(h(X) - Y\right)^2\right).\]

<p><div align="justify"><em>That is, the quantity we now care about is the expected error under a different distribution.</em></div></p>
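<p><div align="justify">To make the two expectations concrete, here is a minimal Monte Carlo sketch under a hypothetical setup that is not the one used in this post: $Y = \sin(X) + \varepsilon$ with small Gaussian noise, $X_{\textrm{old}} \sim \mathcal{N}(0, 1)$, $X_{\textrm{new}} \sim \mathcal{N}(2, 1)$, and a fixed predictor $h$ given by a straight line fitted by least squares on an old sample. The relationship between $Y$ and $X$ is the same in both cases; only the distribution over which we average the squared error of the same $h$ changes.</div></p>

```python
import numpy as np

rng = np.random.default_rng(0)


def f(x):
    # Hypothetical true regression function, shared by both distributions.
    return np.sin(x)


def sample_y(x):
    # Y = f(X) + noise: the conditional law of Y given X never changes.
    return f(x) + rng.normal(scale=0.1, size=x.shape)


# Hypothetical old covariate distribution: X_old ~ N(0, 1).
x_old = rng.normal(loc=0.0, scale=1.0, size=10_000)
y_old = sample_y(x_old)

# A fixed predictor h: the line minimizing the empirical mean squared
# error on the old sample (least squares over linear functions).
coefs = np.polyfit(x_old, y_old, deg=1)


def h(x):
    return np.polyval(coefs, x)


def monte_carlo_risk(loc, n=200_000):
    # Monte Carlo estimate of E[(h(X) - Y)^2] when X ~ N(loc, 1).
    x = rng.normal(loc=loc, scale=1.0, size=n)
    y = sample_y(x)
    return np.mean((h(x) - y) ** 2)


risk_old = monte_carlo_risk(loc=0.0)  # expectation over (X_old, Y)
risk_new = monte_carlo_risk(loc=2.0)  # expectation over (X_new, Y), X_new ~ N(2, 1)
print(f"old: {risk_old:.3f}, new: {risk_new:.3f}")
```

<p><div align="justify">The same $h$ that looks good on average under $X_{\textrm{old}}$ has a much larger expected squared error under $X_{\textrm{new}}$, because the line extrapolates poorly where the new covariates concentrate.</div></p>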

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">r2_score</span>

<span class="n">x_plot_new</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">min_X</span><span class="p">,</span> <span class="n">max_X</span><span class="p">,</span> <span class="mi">1000</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

<span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ax</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_past</span><span class="p">,</span> <span class="n">Y_past</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"Old sample"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_new</span><span class="p">,</span> <span class="n">Y_new</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.6</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"New sample"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x_plot_new</span><span class="p">,</span> <span class="n">f</span><span class="p">(</span><span class="n">x_plot_new</span><span class="p">),</span> <span class="n">c</span><span class="o">=</span><span class="s">"k"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"f(x)"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span>
    <span class="n">x_plot_new</span><span class="p">,</span>
    <span class="n">grid_search</span><span class="p">.</span><span class="n">best_estimator_</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">x_plot_new</span><span class="p">),</span>
    <span class="n">c</span><span class="o">=</span><span class="s">"r"</span><span class="p">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"Decision tree estimator trained on old sample"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"x"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"y"</span><span class="p">)</span>
<span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="s">"lower left"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
</code></pre></div></div>

<p><div align="justify"><center><img src="/assets/img/covariate_0_introduction/output_13_0.png" /></center></div></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">r2_score</span><span class="p">(</span><span class="n">Y_new</span><span class="p">,</span> <span class="n">grid_search</span><span class="p">.</span><span class="n">best_estimator_</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_new</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>0.059081313039643146
</code></pre></div></div>

<p><div align="justify">As anticipated, the model's performance deteriorates on the new data, with $R^2$ falling to about 0.06. It is important to remember that the relationship between $Y$ and $X$ has remained the same; only the distribution of $X$ has shifted.</div></p>
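<p><div align="justify">As a reminder of how to read this number, $R^2 = 1 - \sum_i (y_i - \hat{y}_i)^2 / \sum_i (y_i - \bar{y})^2$ compares the model's squared error with that of a constant predictor equal to the sample mean $\bar{y}$. A value close to zero therefore means the model predicts the new sample barely better than that constant. The sketch below recomputes the score by hand on small made-up arrays (hypothetical values, not the post's data) and checks it against <code>sklearn.metrics.r2_score</code>.</div></p>

```python
import numpy as np
from sklearn.metrics import r2_score


def r2_by_hand(y_true, y_pred):
    # Residual sum of squares of the model.
    ss_res = np.sum((y_true - y_pred) ** 2)
    # Total sum of squares: the error of always predicting the mean of y_true.
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


# Hypothetical toy targets and predictions.
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.5, 2.5, 2.5, 3.5])
print(r2_by_hand(y_true, y_pred), r2_score(y_true, y_pred))

# Predicting the mean everywhere gives R^2 = 0 by construction.
mean_pred = np.full_like(y_true, y_true.mean())
print(r2_by_hand(y_true, mean_pred))
```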

<hr />

<h2 id="identifying-covariate-shift">Identifying covariate shift</h2>

<p><div align="justify">With the initial problem established, our challenge can be summarized as follows:</div></p>

<p><div align="justify">Let $X$ and $Z$ be random variables (or vectors). Assume you draw $N\in\mathbb{N}^*$ independent samples of $X$ and $M\in \mathbb{N}^*$ independent samples of $Z$, obtaining $\{x_1, x_2, \cdots, x_N \}$ and $\{z_1, z_2, \cdots, z_M \}$. Using only these two samples, how can we determine whether $X$ and $Z$ have the same distribution, denoted $X\sim Z$? In the context of covariate shift, specifically, we will be comparing samples of the covariates from the training phase with samples from production.</div></p>
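<p><div align="justify">A natural first answer compares the empirical cumulative distribution functions of the two samples, $\hat{F}_N(t) = \frac{1}{N}\#\{i : x_i \le t\}$ and $\hat{G}_M(t) = \frac{1}{M}\#\{j : z_j \le t\}$: if $X \sim Z$, both should approach the same curve as $N$ and $M$ grow. The minimal sketch below computes the largest vertical gap between them, which is exactly the two-sample Kolmogorov-Smirnov statistic. The samples are synthetic and hypothetical, and the comparison with <code>scipy.stats.ks_2samp</code> is only a sanity check.</div></p>

```python
import numpy as np
from scipy.stats import ks_2samp


def ecdf(sample, t):
    # Fraction of points in `sample` that are <= each value in `t`.
    sorted_sample = np.sort(sample)
    return np.searchsorted(sorted_sample, t, side="right") / len(sorted_sample)


def max_ecdf_gap(x, z):
    # Both ECDFs are step functions that only jump at observed points, so the
    # supremum of |F_N(t) - G_M(t)| is reached on the pooled sample.
    grid = np.concatenate([x, z])
    return np.max(np.abs(ecdf(x, grid) - ecdf(z, grid)))


rng = np.random.default_rng(42)
x = rng.normal(loc=0.0, scale=1.0, size=500)  # e.g. a covariate at training time
z = rng.normal(loc=0.5, scale=1.0, size=300)  # e.g. the same covariate in production

print(max_ecdf_gap(x, z), ks_2samp(x, z).statistic)
```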

<p><div align="justify">In general, monitoring the distribution of covariates needs to be easy to implement, so simple and computationally cheap methods are usually preferred over complex ones. Moreover, the analysis is typically performed on each covariate separately, looking for shifts in these marginal distributions. Among the classic univariate methods, the most prominent are:</div></p>

<ul>
  <li>
    <p><div align="justify">Comparison of statistics: means, variances, selected sample quantiles, etc.;</div></p>
  </li>
  <li>
    <p><div align="justify">Comparison of frequencies for discrete distributions and categorical data;</div></p>
  </li>
  <li>
    <p><div align="justify">Kolmogorov-Smirnov test;</div></p>
  </li>
  <li>
    <p><div align="justify">Kullback-Leibler divergence.</div></p>
  </li>
</ul>
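<p><div align="justify">A minimal sketch of these univariate checks, assuming the reference (training) and production covariates are NumPy arrays with one column per feature. For each column it compares means, variances and a few quantiles, runs <code>scipy.stats.ks_2samp</code>, and estimates the Kullback-Leibler divergence $D_{\textrm{KL}}(P_{\textrm{ref}} \,\|\, P_{\textrm{prod}})$ from histograms on common bins with <code>scipy.stats.entropy</code>; a similar KL on category frequencies covers categorical data. The number of bins, the quantile levels and the smoothing constant <code>eps</code> are arbitrary choices of this sketch, and histogram-based KL estimates depend noticeably on them. All data are simulated.</div></p>

```python
import numpy as np
from scipy.stats import entropy, ks_2samp


def histogram_kl(ref, prod, bins=20, eps=1e-6):
    # Common bin edges so that both histograms are comparable.
    edges = np.histogram_bin_edges(np.concatenate([ref, prod]), bins=bins)
    p, _ = np.histogram(ref, bins=edges)
    q, _ = np.histogram(prod, bins=edges)
    # Smooth empty bins to avoid log(0) and division by zero;
    # entropy() normalizes both count vectors to probabilities.
    return entropy(p + eps, q + eps)


def category_kl(ref, prod, eps=1e-6):
    # Compare category frequencies over the union of observed categories.
    categories = np.union1d(ref, prod)
    p = np.array([np.sum(ref == c) for c in categories], dtype=float)
    q = np.array([np.sum(prod == c) for c in categories], dtype=float)
    return entropy(p + eps, q + eps)


def univariate_report(X_ref, X_prod, levels=(0.1, 0.5, 0.9)):
    # One entry per covariate, i.e. per marginal distribution.
    report = []
    for j in range(X_ref.shape[1]):
        ref, prod = X_ref[:, j], X_prod[:, j]
        ks = ks_2samp(ref, prod)
        report.append({
            "feature": j,
            "mean_diff": prod.mean() - ref.mean(),
            "var_ratio": prod.var() / ref.var(),
            "quantile_diff": np.quantile(prod, levels) - np.quantile(ref, levels),
            "ks_statistic": ks.statistic,
            "ks_pvalue": ks.pvalue,
            "kl": histogram_kl(ref, prod),
        })
    return report


rng = np.random.default_rng(0)
X_ref = rng.normal(size=(2_000, 2))
# Hypothetical production data: only the first covariate shifts.
X_prod = np.column_stack([rng.normal(loc=1.0, size=1_000), rng.normal(size=1_000)])

for row in univariate_report(X_ref, X_prod):
    print(row)

# Hypothetical categorical covariate: 50/50 before, 80/20 in production.
colors_ref = np.array(["red"] * 50 + ["blue"] * 50)
colors_prod = np.array(["red"] * 80 + ["blue"] * 20)
print(category_kl(colors_ref, colors_prod))
```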

<p><div align="justify">This monitoring is often accompanied by an analysis of the model's output distribution. For instance, if our model previously predicted that 10% of the data belonged to a given class and now predicts 20%, we have a strong hint that the input distribution has shifted, since the model itself has not changed.</div></p>
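<p><div align="justify">A minimal sketch of this output check for a classifier, assuming we stored the model's predicted labels on a reference sample and collect them again in production. It compares the predicted positive rates and, as one possible choice not prescribed by this post, applies <code>scipy.stats.chi2_contingency</code> to the 2×2 table of predicted counts to judge whether the difference exceeds what sampling noise would explain. The labels are simulated to mimic the 10% versus 20% scenario.</div></p>

```python
import numpy as np
from scipy.stats import chi2_contingency

rng = np.random.default_rng(0)
# Hypothetical predicted labels: about 10% positives before, about 20% now.
pred_ref = rng.random(5_000) < 0.10
pred_prod = rng.random(5_000) < 0.20

rate_ref, rate_prod = pred_ref.mean(), pred_prod.mean()

# Rows: reference / production; columns: predicted positive / predicted negative.
table = np.array([
    [pred_ref.sum(), (~pred_ref).sum()],
    [pred_prod.sum(), (~pred_prod).sum()],
])
chi2, pvalue, dof, expected = chi2_contingency(table)
print(f"positive rate: {rate_ref:.3f} -> {rate_prod:.3f}, p-value: {pvalue:.2e}")
```

<p><div align="justify">For a regression model such as the decision tree in this post, the same idea applies to the distribution of predicted values, for example by passing the reference and production predictions to <code>scipy.stats.ks_2samp</code>.</div></p>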

<p><div align="justify">In this series of posts, I plan to introduce some slightly more unconventional methods for identifying covariate shift. Subsequently, we'll explore the problem through Vapnik's empirical risk minimization framework. From there, we'll derive an elegant method to address it, using a technique that will serve as a diagnostic tool for identifying dataset shift.</div></p>

<p><div align="justify">$\oint$ <em>Keep in mind that this is just one of the crucial elements when it comes to monitoring machine learning models. For a comprehensive guide that addresses the main potential issues, I recommend the references [<a href="#bibliography">2, 3</a>].</em></div></p>

<h2 id="bibliography"><a name="bibliography">Bibliography</a></h2>

<p><div align="justify">[1] <a href="https://mitpress.mit.edu/9780262545877/dataset-shift-in-machine-learning/">Dataset Shift in Machine Learning. The MIT Press. Joaquin Quiñonero-Candela, Masashi Sugiyama, Anton Schwaighofer and Neil D. Lawrence.</a></div></p>

<p><div align="justify">[2] <a href="https://towardsdatascience.com/monitoring-machine-learning-models-in-production-how-to-track-data-quality-and-integrity-391435c8a299">Monitoring Machine Learning Models in Production. Towards Data Science. Emeli Dral.</a></div></p>

<p><div align="justify">[3] <a href="https://developer.nvidia.com/blog/a-guide-to-monitoring-machine-learning-models-in-production/">A Guide to Monitoring Machine Learning Models in Production. NVIDIA Developer Blog. Kurtis Pykes.</a></div></p>
<hr />

<p><div align="justify">You can find all files and environments for reproducing the experiments in the <a href="https://github.com/vitaliset/vitaliset.github.io/tree/master/code/covariate_introduction">repository of this post</a>.</div></p>]]></content><author><name>Carlo Lemos</name></author><category term="[&quot;🇺🇸&quot;, &quot;🇧🇷&quot;, &quot;dataset shift&quot;]" /><summary type="html"><![CDATA[Introducing the dataset shift scenario with an illustrative case.]]></summary></entry></feed>