{ "cells": [ { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-10-12T15:49:15.445719Z", "start_time": "2019-10-12T15:49:15.442060Z" } }, "source": [ "(ch4)=\n", "# Gradient Descent: Part II\n", "\n", "{ref}`Chapter 3` showed how to leverage the gradient descent approach to find the minimal point for a given function. This chapter shows how the gradient descent is used in linear regression. In the case of linear regression, we have a bunch of observed data points, and the task is to identify the underlying function that generated it. For instance, let's assume the underlying hidden function is $y=4x_1^2 + 2x_2^2$. Below code, snippet uses this function to generate 500 data points. " ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-11-16T01:47:56.320030Z", "start_time": "2019-11-16T01:47:55.095902Z" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x1x2y
588346713602
17424466536
117-43188044
865-90-1132642
326-394610316
98646-8823952
35197-8652428
2594-627752
870-647126466
83428518338
\n", "
" ], "text/plain": [ " x1 x2 y\n", "588 34 67 13602\n", "174 24 46 6536\n", "117 -43 18 8044\n", "865 -90 -11 32642\n", "326 -39 46 10316\n", "986 46 -88 23952\n", "351 97 -86 52428\n", "259 4 -62 7752\n", "870 -64 71 26466\n", "834 28 51 8338" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.plotly.v1+json": { "config": { "linkText": "Export to plot.ly", "plotlyServerURL": "https://plot.ly", "showLink": false }, "data": [ { "mode": "markers", "opacity": 0.5, "type": "scatter3d", "x": [ 46, -92, 9, 23, 47, -97, -48, 18, 25, -29, 67, -59, -92, 33, 25, -17, -81, -37, 90, -8, -89, 7, -65, 54, -10, -3, 7, -28, 72, -33, 16, -56, 75, -23, 69, -8, -66, 16, 96, -39, 12, 57, -4, -89, 49, -99, -40, -66, -51, -23, 37, -7, 97, -39, -20, 70, 40, 15, 11, 20, -84, 67, 49, -17, 28, -60, -43, 5, -39, -91, -92, 27, -23, 55, 68, -82, 36, -80, -62, -2, 45, -5, 53, -62, -72, 98, 97, -76, 13, -58, -52, -11, 10, 6, 14, -38, 74, -30, -64, 58, 33, -55, -70, -32, 16, -23, -58, 69, 66, -56, 99, -55, 22, 97, -12, -17, 11, -43, -99, 39, 82, -89, -16, -19, -38, -80, -33, 14, 3, 49, -60, -1, 26, 72, -39, 88, 34, -31, 33, 23, 54, 27, -84, -58, 25, 72, 18, 2, -66, 7, 38, 49, -11, 37, -1, 25, 79, -58, 43, 12, -75, 6, -91, -100, 12, 37, -84, -88, -11, -77, -62, -73, 55, 16, 24, -65, 17, 11, 98, 28, 14, -13, 85, -34, 78, 17, 8, -16, 66, 33, 33, -61, 89, -40, -18, 58, -92, 75, -50, 64, 50, 47, 60, 22, -94, -71, -25, 11, -37, 83, -26, -78, -89, 52, 40, 96, -73, -27, -59, 71, 22, -64, -8, 52, 2, -51, -11, -54, -80, 51, -51, 9, -47, -76, 75, -22, 18, 90, 92, 77, 33, 84, 8, 14, 70, 90, 42, -47, -32, 52, -7, -71, 60, 100, 60, -99, -47, 8, 56, 4, -3, 66, -49, -83, 97, 8, -78, -86, -11, -35, -8, 16, 8, 94, -51, 60, 91, -28, 19, 77, -25, 29, 27, 29, -33, 3, 13, 63, -13, -11, 85, -30, 23, 14, 22, 50, 22, -79, 76, 38, 84, -7, 84, 3, -57, 54, 6, 84, 44, -84, 52, 87, -22, -96, 51, -17, 66, -18, -82, 89, 1, 56, -5, -28, 55, -53, -39, -57, -22, -76, 42, -87, -44, 45, 90, -82, 7, -51, 45, 46, 12, 51, -96, 0, -49, -14, 47, -52, 33, 44, 95, 97, -17, -88, 68, 77, 48, -88, -91, 44, -7, 1, -17, 53, 63, -57, 97, -60, -70, 55, -12, -58, -93, 85, 37, 30, -11, -78, 1, 12, 47, -11, -85, -63, 13, 80, 20, 56, 15, 84, 13, 54, -56, -10, 8, -13, 43, 31, 17, -1, -51, -76, 10, 64, -61, -90, -53, 63, 62, -52, 45, 9, -8, 43, -4, -88, 54, -87, -58, -21, -43, -88, 48, -55, -75, 62, 44, -8, -53, -23, -90, 1, 53, 95, 9, -81, -12, 94, -51, 25, -59, -57, 3, 92, 19, 98, -16, -91, -56, -55, 48, 80, 7, -61, 31, 64, 78, 0, -76, 1, -70, -40, -73, 8, 84, 63, -21, 85, -26, 36, 66, -75, -12, 33, 58, 21, 55, -38, 36, -26, -97, 70, -60, 73, -2, 77, -42, -76, 30, 63, 95, 21, -75, 47, 2, 49, -45, -27, -10, -83, -52, 59, 42, 35, -94, -81, 2, -92, -96, 67, 10, -14, -55, 86, -88, -9, -17, 60, 21, -57, -96, -48, -68, -26, -49, -37, -37, -2, 57, 39, 49, 52, 7, 41, 94, -93, -68, -99, -69, -93, -86, 75, 97, 96, -55, -69, 77, -2, -83, -97, 69, 11, 89, 69, -54, 59, 7, -88, 41, -57, -87, 79, -4, 48, 17, 16, -15, -53, -92, 2, 60, 17, 25, 81, 82, -7, 70, -82, -65, -33, -36, -55, -77, 17, -97, -61, -29, -80, 34, -22, -64, 96, -74, 37, -70, -27, 80, 58, -15, -100, 47, -34, 62, 92, 32, 9, 30, -63, 82, -85, 62, -86, 55, -85, 65, -46, 37, -97, -46, -63, -22, 76, 7, -62, -94, -5, 32, 71, 81, -71, 22, -24, 18, 18, -56, 60, -64, -93, 3, 55, -61, 15, 68, 79, 97, -55, -14, 29, 71, 20, -61, -89, 25, 51, -1, -58, 20, 79, -30, -46, 26, 33, -10, 67, 44, 96, 42, -3, -62, 64, -19, 
89, -19, -67, -86, 17, -87, 72, -15, 35, 33, -88, -3, 56, 98, -80, 78, 84, -91, -68, -92, -98, -7, 11, -44, 68, 22, 44, 69, 85, -77, -74, -83, -7, 16, 9, -21, -55, 91, -68, 95, 61, 2, 39, 78, 38, -51, 74, 90, -39, 84, -61, -3, -45, 31, 53, 25, 75, 69, -2, -52, -83, 37, 86, 91, -61, 55, 51, 37, -67, -82, 68, 1, 16, -78, -47, -97, 16, -55, 70, -57, -7, -2, 61, -47, -37, -15, -24, 50, -34, -61, -85, 24, 6, -9, -13, -99, -17, 72, -25, 28, 98, 78, 50, -88, 4, 82, -28, 17, -57, -71, 7, 9, -77, -70, 18, -59, -82, -4, 36, -80, -99, 33, 85, 40, -34, 46, 23, -79, 7, -71, 92, 4, 22, -44, 0, -95, 13, 17, 62, 45, 82, 26, 78, 7, -24, 43, 68, -61, 15, -58, -42, 20, 27, 46, 19, 3, -23, 91, 41, 14, 53, 94, -95, 28, 39, -6, 61, 80, 18, -23, -53, -40, -43, -5, 54, -59, 35, 99, -94, 38, 44, -92, -86, -58, 7, 94, -22, -58, 96, -50, -93, -57, 19, -20, -90, 29, -38, -91, -96, -64, 73, 25, 16, 97, -10, -95, 7, 61, 74, -9, -34, -96, -61, -35, -90, -64, 67, 75, 70, -39, -71, 31, 82, 21, 98, 96, -25, -11, -19, 41, 9, 99, 78, -16, 52, -25, 9, -86, -73, 23, 70, 85, 83, -8, 24, 85, 11, 20, -9, -83, -60, 48, -52, 40, 93, -18, -27, -50, -99, 86, 49, 9, 85, 96, -90, 95, 87, -57, -59, 23, -8, 0, -70, -1, -37, -49, 47, -81, 25, -87, 40, -35, 25, -79, 99, 90, 68, 56, 78, -78, -38, 65, -56, -57, 15, 3, 80, 43, -52, 75, 8, 79, -42, 11, 5, -52, -41, 85, 79, -8, 23, 82, -67, 93, -50, 46, 68, 89, -87, 86, 75, 33, 79, -53, -52, -80, 84, 0, 63 ], "y": [ 74, 38, -100, 54, 28, -42, 76, 82, 45, -17, -20, 25, 49, 60, -56, 20, -38, -4, 2, 18, 30, -73, 7, 54, 32, 86, -77, -37, -85, -73, 85, -84, -31, -37, -29, 84, 79, -13, 93, -27, 36, -6, -66, -11, 97, 45, 63, -68, 79, -82, -97, 93, 97, 15, 65, -70, 3, 57, 58, -95, -59, 52, -66, 4, -39, 21, 18, -84, -57, -80, -13, 66, 68, -32, 96, -61, 63, -86, 13, 79, 88, 33, -90, 68, -87, -92, 12, 35, 5, 40, 37, -21, 90, 54, 47, 74, -8, -26, 51, -66, 18, 85, -15, -9, -83, -34, 80, -26, 67, -11, 79, 43, -74, -36, 60, -4, -20, 18, 9, 3, -58, -30, -71, -62, 42, 58, -57, -100, 13, 94, 95, -8, 10, -32, -60, -20, -52, 74, 63, 71, -68, -88, -87, 95, 49, -57, 44, 37, 40, -62, -85, 47, -25, -10, 18, 94, -7, -22, -35, -82, -50, 67, -10, -83, 65, -82, 89, 44, -37, 13, 100, 38, 74, -40, 46, 59, -32, 42, 72, -26, -88, 28, 10, 68, 84, -14, 19, 53, 78, -86, -5, -15, 22, 25, -28, 37, 24, -100, -5, -36, 49, 48, 65, 20, -54, -42, -1, 11, -18, 74, -30, -73, -13, -72, -77, -73, -35, 11, -40, -55, 39, 74, 95, -56, 96, -99, 28, 41, -27, 63, -59, -13, 46, -51, 51, 84, 77, -69, -58, -20, -98, -13, 21, -52, 14, -86, 52, 98, 22, 45, -66, -96, 42, -25, 30, -55, 33, 51, -74, -62, 27, 60, 91, -21, 47, -24, -88, -1, -58, -64, -4, 67, 66, -98, 22, 10, 4, 55, 45, 61, 63, -96, -93, 49, -58, 78, 97, -30, 3, -51, 51, 30, 53, 82, 23, -3, -40, -13, 38, -43, -8, 47, 49, -91, 46, -98, -63, -97, -67, -48, -31, -38, 90, -81, 25, 5, 86, -85, 15, -6, 78, -100, 60, -46, -55, 57, 46, -48, 77, 93, -11, -93, -25, -61, 32, 54, 25, -2, -78, -68, 20, -81, -23, 70, -54, 26, 98, 81, 21, -26, 94, -86, 22, -25, -89, -8, -83, -80, -10, 34, 9, -91, -72, 74, 42, 66, -57, -68, -89, -42, 71, 7, -88, -81, -24, -89, 5, -6, -15, 31, 55, 56, 56, -25, -2, -35, -1, -97, 42, 38, -100, 19, -42, 42, -83, -4, 38, 79, -10, -32, 71, 39, -91, -88, -17, 49, 69, 19, -21, -18, -57, -60, -12, 26, -76, 90, -8, -44, -20, -87, 63, -22, 89, -99, 31, -28, -72, 90, 91, 82, -8, -87, -67, 46, -93, -31, -42, -10, 37, -96, 55, -24, 39, 74, -3, 68, 82, 39, 62, 0, 59, 14, -19, 85, -99, 69, -68, -25, -52, -10, -12, 60, 28, -28, 6, 78, -76, 15, -64, -43, -3, 37, -16, 72, 73, -24, -37, 50, 
75, 11, -95, -84, -42, -3, 96, 79, -47, -60, -88, -79, -8, -5, -73, 72, -53, 89, 68, 53, 69, -85, 64, -81, -95, 5, -40, 79, -83, 42, 71, 11, 27, -34, -72, -96, 38, -74, -50, 45, 59, 61, -81, 95, 95, -4, 65, -33, -81, -17, 54, -71, -24, 82, -41, -39, 14, -68, 23, 68, 100, -18, 93, -25, -49, 10, -87, 11, -68, 11, -35, -64, 92, -7, -64, 44, 84, 99, -27, 70, 45, 63, 62, 30, 41, 3, -48, 25, 3, 81, -76, -34, -88, -60, 0, 38, 74, -75, -64, 32, 85, -43, -30, -44, -72, -85, -86, 77, 7, 2, 67, 28, 79, -92, 96, 85, -75, -67, 8, -22, -88, -44, 8, -20, -28, 12, 57, 38, 10, -95, 55, -45, -59, 70, 85, -47, -82, -92, 57, -64, 55, -31, 92, -18, -36, 34, 53, -12, 67, -88, 73, 94, 39, -65, 94, 47, -47, 1, -36, 24, 8, 46, 0, 72, 12, -9, 6, -18, 43, -42, -19, 37, 0, -90, -95, 93, 77, 24, 18, -57, -91, 14, 93, -39, -92, 21, 51, -57, -81, -39, 79, 65, 81, 67, -74, 38, 88, 65, 97, -60, 37, -46, 55, -5, 56, 98, -59, 82, -88, 32, 19, -39, -64, 81, 71, 19, 64, -82, 97, 99, 65, -38, 26, -75, -86, -56, 66, 98, 72, -57, -66, 56, -40, -81, 18, -90, -71, -76, -82, 79, 96, -85, -75, 22, 12, -81, 52, 53, 49, -47, 91, 95, 97, -73, -49, -72, 63, 47, 81, 68, -5, -58, 35, -30, -19, -52, 28, 38, -47, 11, -49, 36, 14, 19, 49, 39, 89, -24, 68, 14, 96, -100, -12, -100, -34, 34, -35, -74, 56, 4, 5, -23, 10, -81, -24, -59, 97, -12, -64, -72, -84, -4, 62, -32, -53, -39, -89, -13, -35, 96, 10, 1, 32, -84, 86, 81, -33, 85, 87, 73, -13, 81, 1, -37, 65, -98, 91, -26, -42, -23, -39, -96, -48, -56, -52, -98, -70, -33, -33, 20, -72, -11, -79, 49, 68, 41, 80, 3, 75, -18, -74, 1, 4, 8, 95, 7, 51, 81, 97, -88, 86, -9, -15, -100, 28, 53, -74, 11, 38, 47, 31, -88, 54, 11, -70, 27, 46, -60, -1, -81, -33, -71, -65, 79, 83, 70, -32, -11, -23, 54, 36, -38, 71, -76, 10, 9, -15, 93, -46, -3, -83, 57, 10, -70, -70, 18, 24, -17, -15, -64, 46, -90, -96, 35, 76, 22, 63, -34, 11, 88, -34, -62, -67, 0, -24, -85, 46, -7, 46, 17, 39, 82, -5, 21, -97, 45, -25, -61, 29, 70, -32, -56, 55, -48, 29, 58, -67, -82, -50, -77, 20, -16, -18, 93, 12, -13, -69, -7, 80, -59, 83, -76, 75, 37, -92, 33, 69, 86, 44, -34, -78, -7, 54, 53, 85, -22, 45, -7, -66, 97, 17, -79, 65, -25, 85, -33, 17, 10, -84, -77, 31, -9, 25, 13, -10, 3, 65, -60, -5, 59, -42, 78, -89, -45, -3, -91, 44, -76, -88, 81, -48, 51, -14, 25, 4, 83, -97, 17, 7, -98, -81, 92 ], "z": [ 19416, 36744, 20324, 7948, 10404, 41164, 20768, 14744, 6550, 3942, 18756, 15174, 38658, 11556, 8772, 1956, 29132, 5508, 32408, 904, 33484, 10854, 16998, 17496, 2448, 14828, 12054, 5874, 35186, 15014, 15474, 26656, 24422, 4854, 20726, 14368, 29906, 1362, 54162, 7542, 3168, 13068, 8776, 31926, 28422, 43254, 14338, 26672, 22886, 15564, 24294, 17494, 56454, 6534, 10050, 29400, 6418, 7398, 7212, 19650, 35186, 23364, 18316, 1188, 6178, 15282, 8044, 14212, 12582, 45924, 34194, 11628, 11364, 14148, 36928, 34338, 13122, 40392, 15714, 12498, 23588, 2278, 27436, 24624, 35874, 55344, 37924, 25554, 726, 16656, 13554, 1366, 16600, 5976, 5202, 16728, 22032, 4952, 21586, 22168, 5004, 26550, 20050, 4258, 14802, 4428, 26256, 20396, 26402, 12786, 51686, 15798, 12888, 40228, 7776, 1188, 1284, 8044, 39366, 6102, 33624, 33484, 11106, 9132, 9304, 32328, 10854, 20784, 374, 27276, 32450, 132, 2904, 22784, 13284, 31776, 10032, 14796, 12294, 12198, 20912, 18404, 43362, 31506, 7302, 27234, 5168, 2754, 20624, 7884, 20226, 14022, 1734, 5676, 652, 20172, 25062, 14424, 9846, 14024, 27500, 9122, 33324, 53778, 9026, 18924, 44066, 34848, 3222, 24054, 35376, 24204, 23052, 4224, 6536, 23862, 3204, 4012, 48784, 4488, 16272, 2244, 29100, 13872, 38448, 1548, 978, 
6642, 29592, 19148, 4406, 15334, 32652, 7650, 2864, 16194, 35008, 42500, 10050, 18976, 14802, 13444, 22850, 2736, 41176, 23692, 2502, 726, 6124, 38508, 4504, 34994, 32022, 21184, 18258, 47522, 23766, 3158, 17124, 26214, 4978, 27336, 18306, 17088, 18448, 30006, 2052, 15026, 27058, 18342, 17366, 662, 13068, 28306, 27702, 16048, 13154, 41922, 40584, 24516, 23564, 28562, 1138, 6192, 19992, 47192, 12464, 28044, 5064, 14866, 8908, 38596, 17928, 41250, 16200, 45254, 11014, 5458, 23496, 7752, 1494, 24624, 26166, 28438, 42054, 1408, 39824, 29586, 7212, 13092, 288, 10002, 8968, 54552, 11372, 14600, 33156, 9186, 5494, 31158, 10438, 21796, 20214, 8166, 11084, 12204, 19494, 17676, 694, 5686, 34102, 5400, 7734, 14232, 2994, 10018, 5136, 25302, 25992, 9474, 28352, 4614, 33026, 16598, 17228, 30872, 8082, 47042, 16722, 32832, 12738, 33164, 18136, 49986, 11654, 1206, 32216, 15746, 27346, 31756, 12172, 32544, 7300, 7368, 18150, 17734, 10316, 17604, 13794, 40402, 7298, 47574, 8994, 15542, 34448, 32728, 1446, 10412, 20268, 17712, 1376, 23526, 37922, 9800, 15436, 2136, 28044, 23938, 5238, 9096, 53772, 52428, 2124, 32226, 34338, 23844, 22994, 43776, 33324, 10056, 358, 16566, 11524, 22188, 19404, 21708, 44134, 23648, 35442, 15628, 10658, 13554, 50084, 42022, 6628, 19442, 534, 24408, 454, 2498, 14886, 6756, 35172, 17126, 684, 28050, 1602, 31362, 4428, 31112, 20676, 12386, 16072, 3928, 14034, 708, 10284, 16326, 1356, 2052, 20486, 26146, 16962, 31872, 15462, 37202, 20758, 16598, 16258, 11464, 14598, 7524, 544, 8748, 11616, 47176, 11792, 34148, 14256, 16902, 15334, 31944, 25058, 31702, 24422, 16944, 18112, 16456, 27798, 15564, 32528, 15142, 20214, 40332, 17622, 28166, 4104, 35544, 13142, 20932, 19974, 14148, 3078, 44808, 1462, 47664, 14472, 36166, 20232, 12100, 16178, 25992, 918, 29334, 23446, 25906, 33584, 1250, 28512, 204, 19888, 13600, 22884, 1824, 28296, 28044, 13316, 29350, 10896, 8882, 17442, 25238, 1088, 14724, 24114, 2916, 14838, 10776, 16434, 2946, 55686, 33712, 17928, 21334, 18448, 36198, 11474, 30304, 19088, 28358, 36228, 1814, 33158, 19204, 5634, 25446, 17348, 8534, 9922, 42006, 19008, 27046, 25106, 4950, 38544, 38726, 13794, 37384, 46946, 18198, 1858, 3096, 22468, 48016, 33864, 11276, 6156, 18450, 8726, 20438, 49986, 27266, 36546, 2736, 18054, 7654, 18598, 594, 18828, 16166, 10756, 24264, 3558, 9766, 35736, 43844, 19554, 48452, 39044, 35244, 46882, 23750, 42438, 37064, 27238, 19286, 32964, 258, 30006, 45828, 35972, 582, 39876, 22916, 25776, 33526, 1654, 40776, 10774, 20934, 37964, 26764, 3426, 9234, 5764, 2274, 918, 24358, 45408, 2328, 29888, 8356, 2500, 29132, 37848, 11446, 27792, 28944, 31350, 8054, 6984, 15972, 34084, 15606, 52428, 26742, 3462, 25608, 13602, 3504, 28866, 53792, 40336, 19926, 30850, 11894, 25728, 14424, 16388, 43872, 8964, 5424, 16944, 34144, 10594, 3212, 3800, 33926, 32946, 32950, 22338, 39384, 26550, 33318, 30348, 25392, 11974, 45828, 14514, 17798, 18864, 23752, 2788, 17688, 40962, 388, 13074, 35652, 36902, 37836, 4978, 10754, 18968, 5714, 16962, 14402, 18976, 35748, 164, 16332, 14884, 11268, 18784, 25126, 37708, 12748, 4482, 6892, 20886, 4338, 14884, 47884, 20550, 27702, 11862, 14608, 2248, 31462, 20162, 8856, 20002, 7398, 17328, 18838, 12946, 43362, 20178, 3078, 27858, 24834, 14566, 40662, 12396, 20844, 45072, 9606, 49094, 27936, 3638, 9132, 10406, 31026, 6308, 31752, 45378, 39048, 39824, 30272, 33846, 21538, 42048, 51538, 10278, 1206, 15936, 31944, 20754, 27346, 27494, 31788, 25068, 33154, 42348, 6468, 9736, 19532, 12132, 18598, 41836, 24768, 39300, 28006, 664, 22284, 34418, 
17328, 23852, 34386, 50832, 20534, 39474, 15852, 324, 21222, 9252, 16854, 7302, 26918, 35606, 18066, 29634, 38214, 10278, 39952, 41062, 19302, 25222, 19652, 5526, 24684, 29346, 20296, 726, 6432, 25904, 11724, 42054, 1266, 16902, 22192, 13388, 918, 4818, 17926, 24678, 6628, 10148, 2696, 28432, 24624, 15172, 48900, 4616, 2456, 2774, 11628, 45476, 1188, 20786, 3558, 3336, 51538, 25488, 16962, 49794, 352, 35088, 13504, 15268, 13028, 27852, 2244, 5942, 26758, 35442, 1634, 16374, 45328, 264, 5186, 27648, 53316, 19148, 42022, 8578, 19074, 23602, 12774, 25302, 13318, 20166, 36594, 8514, 21144, 24306, 1352, 39628, 1734, 4198, 33808, 12708, 33168, 8112, 43544, 9996, 4482, 9574, 19296, 25252, 1142, 25938, 11858, 10848, 6278, 21264, 1462, 11286, 2764, 44076, 6726, 816, 11364, 53394, 36198, 8338, 19206, 18962, 30372, 40392, 1458, 2566, 31236, 7968, 13014, 11052, 11906, 16812, 9318, 41126, 50832, 11608, 7986, 43656, 31042, 17688, 7396, 35346, 15058, 15634, 46946, 18450, 47078, 26774, 11244, 3648, 32642, 4422, 11608, 35716, 39752, 26466, 32868, 2700, 1186, 38086, 17698, 40332, 214, 28662, 28402, 524, 14424, 46664, 15532, 6052, 32978, 16834, 26148, 26732, 35800, 24516, 22614, 15396, 27864, 9702, 40728, 37106, 17988, 2796, 9132, 15702, 324, 40356, 38786, 5256, 10914, 6732, 902, 32626, 34764, 2166, 20482, 47718, 31606, 1506, 9746, 30582, 10284, 3648, 6596, 33606, 19008, 10898, 17544, 15378, 48044, 6296, 14774, 10800, 39716, 30232, 26902, 612, 29238, 46386, 32498, 48900, 37238, 26774, 25476, 13366, 2994, 16928, 21778, 9526, 20268, 13476, 11148, 38412, 2598, 36108, 12018, 19350, 3468, 29014, 39302, 41112, 37314, 13122, 36818, 32786, 7026, 31350, 14722, 13574, 1100, 14148, 37458, 9318, 10978, 23750, 594, 25164, 7074, 8934, 7300, 10866, 13686, 32428, 37132, 16098, 6166, 26914, 34518, 38468, 21552, 23952, 31618, 36292, 35478, 29976, 23750, 4388, 38742, 30054, 11394, 25698, 47432, 13122, 32804 ] } ], "layout": { "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], 
"heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, 
"#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } } } }, "text/html": [ "
\n", " \n", " \n", "
\n", " \n", "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from numpy.random import normal\n", "import pandas as pd\n", "import random\n", "\n", "# set seed so that generated data is repetable\n", "random.seed(10)\n", "\n", "# randomly generate 500 values for X1 and x2\n", "dataDF = pd.DataFrame({\n", " 'x1': [random.randint(-100, 100) for x in range(1000, )],\n", " 'x2': [random.randint(-100, 100) for x in range(1000)]\n", "})\n", "\n", "# compute y for given set of X1 and X2\n", "dataDF['y'] = 4 * dataDF['x1']**2 + 2 * dataDF['x2']**2 \n", "\n", "display(dataDF.sample(10))\n", "\n", "from plotly.offline import download_plotlyjs, init_notebook_mode, iplot\n", "import plotly.graph_objs as go\n", "# init_notebook_mode(connected=False)\n", "iplot([go.Scatter3d(x=dataDF['x1'], y=dataDF['x2'], z=dataDF['y'], opacity=0.5, mode='markers')], show_link=False)\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Now the challenge is can we discover the underlying function ($y=4x_1^2 + 2x_2^2$) only by using the above-generated data points .** On the surface, this doesn't look like an optimization function for which we need gradient descent approach. But let's rethink the problem again in reference to {ref}`Chapter 1`. \n", "\n", "{ref}`Chapter 1` discussed various metrics (such as MAE, R Squared, etc.) to evaluate the quality of a regression model. For the given dataset, we can come up with many different possible models such as:\n", "\n", "1. $y = 3x_1^2 + 2x_2^2$\n", "2. $y = 4x_1^2 + 2x_2^2$\n", "3. $y = 4x_1^2 + 3x_2^2$\n", "\n", "Without knowing the underlying function, how do we figure which of these models best fits our observed data points? Well, as discussed in chapter 1, we can compare predicted values to the actual values and compute various statistics such as mean absolute error, R squared, etc. The best model will be the one that minimizes our selected metrics, say MAE. \n", "\n", "For a moment, lets assume the metric we are interested is MAE (refer {ref}`Chapter 1` for details). Mathematically, we can compute MAE for each of our model as follows:\n", "\n", "1. $MAE_1 = \\frac{1}{n} \\sum_{i=1}^{n}|y - (3x_1^2 + 2x_2^2)|$\n", "2. $MAE_1 = \\frac{1}{n} \\sum_{i=1}^{n}|y - (4x_1^2 + 2x_2^2)|$\n", "3. $MAE_1 = \\frac{1}{n} \\sum_{i=1}^{n}|y - (4x_1^2 + 3x_2^2)|$\n", "\n", "Now there is an infinite number of different models that we can come up. So how do we find the one with minimum MAE? An alternative approach is to parameterize our model. All our models have the same form and can be abstractly represented as: \n", "\n", "$$y=\\theta_1x_1^2 + \\theta_2x_2^2$$\n", "\n", "For model 1, $\\theta_1 = 3$ and $\\theta_2 = 2$. For model 2, $\\theta_1 = 4$ and $\\theta_2 = 2$ and so on. Further, we can abstractly represent MAE computation as:\n", "\n", "$$MAE = \\frac{1}{m} \\sum_{i=1}^{m}|y - (\\theta_1x_1^2 + \\theta_2x_2^2)|$$\n", "\n", "Thus, our problem of finding the underlying function that fits our model can now be expressed as minimizing the above equation. Great, now this starts looking like chapter 3. We have a function that we want to minimize. We want to identify optimal $\\theta_1$ and $\\theta_2$ for which MAE is minimum. \n", "\n", "But there is a problem. ** To apply gradient descent, the function should be differentiable**. However, **the \"absolute\" function in MAE is not differentiable**. So we need to look for another metric. The reason we used the \"absolute\" function in MAE is to treat under-estimation and over-estimation equally. 
"\n",
"Now, there is an infinite number of different models that we can come up with. So how do we find the one with minimum MAE? An alternative approach is to parameterize our model. All our models have the same form and can be abstractly represented as: \n",
"\n",
"$$y=\\theta_1x_1^2 + \\theta_2x_2^2$$\n",
"\n",
"For model 1, $\\theta_1 = 3$ and $\\theta_2 = 2$; for model 2, $\\theta_1 = 4$ and $\\theta_2 = 2$; and so on. Further, we can abstractly represent the MAE computation as:\n",
"\n",
"$$MAE = \\frac{1}{m} \\sum_{i=1}^{m}|y_i - (\\theta_1x_1^2 + \\theta_2x_2^2)|$$\n",
"\n",
"Thus, our problem of finding the underlying function that fits the data can now be expressed as minimizing the above equation. Great, now this starts to look like {ref}`Chapter 3`: we have a function that we want to minimize, and we want to identify the optimal $\\theta_1$ and $\\theta_2$ for which the MAE is minimum. \n",
"\n",
"But there is a problem. **To apply gradient descent, the function should be differentiable.** However, **the absolute-value function in MAE is not differentiable** (it has a kink at zero). So we need to look for another metric. The reason we used the absolute-value function in MAE was to treat under-estimation and over-estimation equally. Another way to treat them equally is to take the square of the difference between the actual and the predicted value. The square function is differentiable and hence enables us to use the gradient descent approach for finding the optimal parameters (i.e., $\\theta_1$ and $\\theta_2$). This new metric is referred to as \"Mean Squared Error\" (or MSE), since we are taking the square of the difference between the actual and predicted values. \n",
"\n",
"$$MSE = \\frac{1}{m} \\sum_{i=1}^{m}\\left[y_i - (\\theta_1x_1^2 + \\theta_2x_2^2)\\right]^2$$\n",
"\n",
"The first derivative of the above function with respect to the parameters (i.e., $\\theta_1$ and $\\theta_2$) is given as:\n",
"\n",
"$$MSE' = \\begin{bmatrix}\n",
"-\\frac{2}{m}\\sum_{i=1}^m{[y_i - (\\theta_1x_1^2 + \\theta_2x_2^2)]x_1^2}\\\\ \n",
"-\\frac{2}{m}\\sum_{i=1}^m{[y_i - (\\theta_1x_1^2 + \\theta_2x_2^2)]x_2^2}\\\\ \n",
"\\end{bmatrix}$$\n",
"\n",
"Great! Now we have a function and its first derivative, and we want to minimize it. These are all the components we used in the previous chapter. But before we jump into implementing this, there is one more detail we need to consider. \n",
"\n",
"In {ref}`Chapter 3`, we evaluated the gradient at a single point, i.e., there was no summation over data points. Here, we evaluate the gradient at many different data points and then take their average. This can be hard to picture, and it is worth spending some time understanding what's happening. The code below demonstrates that, taking MSE as the function to minimize, we can identify the optimal parameters using the gradient descent approach.\n",
"\n",
"There is also one more improvement we can make. $\\theta_1x_1^2 + \\theta_2x_2^2$ can be abstractly represented as $\\sum_{j=1}^{n}\\theta_jX_j$, with the assumption that $X_1$ and $X_2$ are already squared for us. The advantage of this approach is that it makes our implementation of the algorithm independent of the form of the equation, as long as the form can be represented as a linear combination of the parameters. Thus, our new MSE and MSE' can be expressed as:\n",
"\n",
"$$MSE = \\frac{1}{m} \\sum_{i=1}^{m}\\left[y_i - \\sum_{j=1}^{n}\\theta_jX_j\\right]^2$$\n",
"\n",
"and \n",
"\n",
"$$MSE' = \\begin{bmatrix}\n",
"-\\frac{2}{m}\\sum_{i=1}^m{[y_i - \\sum_{j=1}^{n}\\theta_jX_j]X_1}\\\\ \n",
"-\\frac{2}{m}\\sum_{i=1}^m{[y_i - \\sum_{j=1}^{n}\\theta_jX_j]X_2}\\\\ \n",
"\\cdots \\\\ \n",
"-\\frac{2}{m}\\sum_{i=1}^m{[y_i - \\sum_{j=1}^{n}\\theta_jX_j]X_n}\\\\ \n",
"\\end{bmatrix}$$\n",
"\n",
"The code snippet below uses this generalized form to compute the gradient. \n"
] },
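{ "cell_type": "markdown", "metadata": {}, "source": [ "Before plugging this gradient into a descent loop, it is worth sanity-checking the formula. The cell below is an illustrative sketch (not part of the original algorithm): it compares the analytic MSE gradient against a central finite-difference approximation on a small synthetic problem, and the two should agree closely.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"def mse(X, y, theta):\n",
"    # mean squared error for a model that is linear in its parameters\n",
"    return np.mean((y - X @ theta) ** 2)\n",
"\n",
"def mse_grad(X, y, theta):\n",
"    # analytic gradient: -(2/m) * sum_i [y_i - X_i . theta] * X_i\n",
"    return -2.0 * X.T @ (y - X @ theta) / len(y)\n",
"\n",
"# small synthetic problem; features are already squared, as assumed above\n",
"rng = np.random.default_rng(0)\n",
"X_check = rng.uniform(-1, 1, size=(50, 2)) ** 2\n",
"y_check = X_check @ np.array([4.0, 2.0])\n",
"theta_check = np.array([1.0, 1.0])\n",
"\n",
"# central finite differences along each parameter axis\n",
"eps = 1e-6\n",
"numeric = np.array([\n",
"    (mse(X_check, y_check, theta_check + eps * e)\n",
"     - mse(X_check, y_check, theta_check - eps * e)) / (2 * eps)\n",
"    for e in np.eye(2)\n",
"])\n",
"\n",
"print(np.allclose(mse_grad(X_check, y_check, theta_check), numeric, atol=1e-4))  # True\n"
] },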
{ "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-10-23T12:53:54.145521Z", "start_time": "2019-10-23T12:53:54.058510Z" } }, "outputs": [ { "data": { "text/html": [ "Optimized Parameters: [3.99999265 2.00000706]" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"from copy import copy\n",
"from random import random\n",
"\n",
"import numpy as np\n",
"from scipy.spatial import distance\n",
"from IPython.display import HTML\n",
"\n",
"\n",
"def gradient(X, y, theta):\n",
"    \"\"\"Return the gradient of MSE with respect to theta.\"\"\"\n",
"    num_samples = X.shape[0]\n",
"\n",
"    # the model is assumed to be a linear combination of the parameters,\n",
"    # hence the prediction is transpose(theta) * X\n",
"    y_prime = np.sum(np.transpose(theta) * X, axis=1)\n",
"    diff = y - y_prime\n",
"\n",
"    # the constant factor 2 in MSE' is dropped here; it only rescales the\n",
"    # gradient and is effectively absorbed into the learning rate eta\n",
"    return -1. * np.sum(diff.reshape((num_samples, 1)) * X, axis=0) / float(num_samples)\n",
"\n",
"\n",
"def minimize(X, y, eta=0.03, max_iterations=1000, traceback=None, stopping_threshold=1.0e-6):\n",
"    \"\"\"Minimize MSE via gradient descent.\n",
"\n",
"    X -- matrix of input features (here, x1 and x2 already squared)\n",
"    y -- actual output values\n",
"    eta -- learning rate\n",
"    max_iterations -- maximum number of descent steps\n",
"    traceback -- optional list for keeping track of intermediate parameters\n",
"    stopping_threshold -- distance between consecutive parameter vectors below which the algorithm stops\n",
"    \"\"\"\n",
"\n",
"    # number of parameters\n",
"    c = X.shape[1]\n",
"\n",
"    # starting point -- randomly selected\n",
"    theta1 = [random() for i in range(c)]\n",
"\n",
"    for iteration in range(max_iterations):\n",
"\n",
"        if traceback is not None:\n",
"            traceback.append(copy(theta1))\n",
"\n",
"        # compute the average gradient over all data points\n",
"        # and move in the opposite direction\n",
"        theta2 = theta1 - eta * gradient(X, y, theta1)\n",
"\n",
"        # stop once consecutive parameter vectors are close enough\n",
"        if distance.euclidean(theta2, theta1) < stopping_threshold:\n",
"            return theta1\n",
"        theta1 = theta2\n",
"\n",
"    # if we reached max iterations, return the current point\n",
"    return theta1\n",
"\n",
"\n",
"# note that we square x1 and x2 up front, so the model is linear in theta\n",
"X = (dataDF[['x1', 'x2']] ** 2).values\n",
"y = dataDF['y'].values\n",
"\n",
"traceback = []\n",
"theta = minimize(X, y, eta=1e-8, max_iterations=10000, traceback=traceback)\n",
"\n",
"display(HTML(\"Optimized Parameters: {}\".format(theta)))\n"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
"Isn't this magic? Our algorithm correctly identified the parameters of our paraboloid just by looking at the generated data points. However, this example is far from reality. In particular, I made a lot of assumptions and ignored some of the details. Below are a few things we should think about:\n",
"\n",
"1. I assumed the functional form of a paraboloid (i.e., $\\theta_1X_1^2 + \\theta_2X_2^2$). In reality, we don't know how the features (i.e., $x_1$ and $x_2$) relate to the target variable. For instance, what if I had started with a functional form of $\\theta_1X_1 + \\theta_2X_2$, or $\\theta_1X_1 + \\theta_2X_2 + \\theta_3X_1X_2$? There are infinitely many possible functional forms, each with its own number of parameters. So, one thing to note is that the gradient descent approach gives the optimal parameters ($\\theta$) for a single functional form. If we change the functional form, we have to re-run the gradient descent algorithm to identify the optimal parameters for the new form; the sketch after this list illustrates the idea. \n",
"\n",
"2. When generating the dataset (X, y), I used a pure paraboloid function. In reality, collected data has some level of noise, and there can be outliers. These outliers can significantly influence the optimal parameters. As discussed in the next chapter, dealing with outliers requires some form of regularization. 
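\n",
"\n",
"To illustrate the first point, here is a hypothetical follow-up (reusing `dataDF`, `y`, and `minimize` from above) that re-runs the same routine on the alternative form $\\theta_1X_1 + \\theta_2X_2 + \\theta_3X_1X_2$. Since the data was generated by a quadratic, this form fits it poorly, but gradient descent still returns its best estimate of the parameters *for this form*:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# feature matrix for the alternative form: x1, x2 and the cross-term x1*x2\n",
"X_alt = np.column_stack([dataDF['x1'], dataDF['x2'], dataDF['x1'] * dataDF['x2']])\n",
"\n",
"# same descent routine, new functional form, new optimal parameters\n",
"theta_alt = minimize(X_alt, y, eta=1e-8, max_iterations=10000)\n",
"print(theta_alt)\n",
"```\n",
"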
\n", "\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }