Uczenie maszynowe — regresja liniowa
Regresja
Termin regresja jest używany, gdy próbujesz znaleźć związek między zmiennymi.
W uczeniu maszynowym i w modelowaniu statystycznym ta relacja służy do przewidywania wyników przyszłych zdarzeń.
Regresja liniowa
Regresja liniowa wykorzystuje relację między punktami danych do narysowania przez nie linii prostej.
Linia ta może służyć do przewidywania przyszłych wartości.
W uczeniu maszynowym bardzo ważne jest przewidywanie przyszłości.
Jak to działa?
Python ma metody znajdowania relacji między punktami danych i rysowania linii regresji liniowej. Pokażemy Ci, jak korzystać z tych metod zamiast przechodzić przez wzór matematyczny.
W poniższym przykładzie oś x reprezentuje wiek, a oś y prędkość. Zarejestrowaliśmy wiek i prędkość 13 samochodów, które przejeżdżały przez punkt poboru opłat. Zobaczmy, czy zebrane przez nas dane można wykorzystać w regresji liniowej:
Przykład
Zacznij od narysowania wykresu punktowego:
import matplotlib.pyplot as plt
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
plt.scatter(x, y)
plt.show()
Wynik:
Przykład
Zaimportuj scipy
i narysuj linię regresji liniowej:
import matplotlib.pyplot as plt
from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
def myfunc(x):
return slope * x + intercept
mymodel = list(map(myfunc, x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()
Wynik:
Przykład wyjaśniony
Zaimportuj potrzebne moduły.
Możesz dowiedzieć się o module Matplotlib w naszym samouczku Matplotlib .
Możesz dowiedzieć się o module SciPy w naszym samouczku SciPy .
import matplotlib.pyplot as plt
from scipy
import stats
Utwórz tablice reprezentujące wartości osi x i y:
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
Wykonaj metodę, która zwraca kilka ważnych kluczowych wartości regresji liniowej:
slope, intercept, r,
p, std_err = stats.linregress(x, y)
Utwórz funkcję, która używa wartości slope
i
intercept
do zwracania nowej wartości. Ta nowa wartość wskazuje, gdzie na osi y zostanie umieszczona odpowiednia wartość x:
def myfunc(x):
return slope * x + intercept
Uruchom każdą wartość tablicy x przez funkcję. Spowoduje to powstanie nowej tablicy z nowymi wartościami dla osi y:
mymodel = list(map(myfunc, x))
Narysuj oryginalny wykres punktowy:
plt.scatter(x, y)
Narysuj linię regresji liniowej:
plt.plot(x, mymodel)
Wyświetl schemat:
plt.show()
R jak związek
Ważne jest, aby wiedzieć, jaki jest związek między wartościami na osi x a wartościami na osi y, jeśli nie ma związku, regresji liniowej nie można użyć do przewidzenia czegokolwiek.
Ta zależność – współczynnik korelacji – nazywa się
r
.
Wartość r
mieści się w zakresie od -1 do 1, gdzie 0 oznacza brak związku, a 1 (i -1) oznacza 100% powiązania.
Python i moduł Scipy obliczą tę wartość za Ciebie, wszystko, co musisz zrobić, to nakarmić ją wartościami x i y.
Przykład
Jak dobrze moje dane pasują do regresji liniowej?
from scipy import stats
x =
[5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
print(r)
Uwaga: Wynik -0,76 pokazuje, że istnieje związek, a nie doskonały, ale wskazuje, że możemy użyć regresji liniowej w przyszłych przewidywaniach.
Przewiduj przyszłe wartości
Teraz możemy wykorzystać zebrane informacje do przewidywania przyszłych wartości.
Przykład: Spróbujmy przewidzieć prędkość 10-letniego samochodu.
Aby to zrobić, potrzebujemy tej samej myfunc()
funkcji z powyższego przykładu:
def myfunc(x):
return slope * x + intercept
Przykład
Wytypuj prędkość 10-letniego samochodu:
from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
def myfunc(x):
return slope * x + intercept
speed = myfunc(10)
print(speed)
Przykład przewidział prędkość 85,6, którą również mogliśmy odczytać z wykresu:
Złe dopasowanie?
Stwórzmy przykład, w którym regresja liniowa nie byłaby najlepszą metodą przewidywania przyszłych wartości.
Przykład
Te wartości dla osi x i y powinny skutkować bardzo złym dopasowaniem regresji liniowej:
import matplotlib.pyplot as plt
from scipy import stats
x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y =
[21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope,
intercept, r, p, std_err = stats.linregress(x, y)
def
myfunc(x):
return slope * x + intercept
mymodel = list(map(myfunc,
x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()
Wynik:
A r
związek za?
Przykład
Powinieneś otrzymać bardzo niską r
wartość.
import numpy
from scipy import stats
x =
[89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y =
[21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
print(r)
Wynik: 0,013 wskazuje na bardzo złą zależność i mówi nam, że ten zestaw danych nie nadaje się do regresji liniowej.