기본 콘텐츠로 건너뛰기

5. Gradient Descent(경사하강법)

Ch6. Gradient Descent(경사하강법)

1. Error Surface

회귀분석에서의 Error는 이제까지 언급함 MSE라고 생각하시면 됩니다. Error surface는 모든 가능한 가중치(일반적으로 통계학에서는 회귀계수 , 으로 표현하고, 학습쪽에서는 가중지 로 표현하고는 합니다.)들에 대한 조합에 대해 Error 값을 계산하여 시각화를 하고자 한 것입니다. 통계학에서는 최소제곱법을 이용하여 회귀계수를 추정하지만, 기계학습에서의 회귀계수 추정은 살짝 다릅니다. 이 Error Surface에서 MSE가 최소점이 되는 가중치의 조합을 찾는 방식으로 회귀계수를 추정합니다.

Error Surface는 일반적으로 볼록한 그릇 모양입니다. 그 의미는 MSE가 최소가 되는 바닥이 존재할 것이며, 그 바닥을 Global minimum이라고 합니다. 이러한 접근법을 저희는 least squares optimization이라고 합니다.

2. Gradient Descent

앞서 말했듯이, 선형 회귀분석에서 global minimum을 찾기 위하여 가능한 모든 가중치의 조합(회귀계수의 조합)을 계산하는 방식은 데이터의 크기가 방대할 경우, 연산 속도가 너무 길어지기 때문에 현실적으로 구현하기에는 어려운 점이 있습니다. 하지만, Error surface에서 global minimum이 존재한다는 사실을 기반으로, 효율성을 높이기 위해 gradient descent(경사 하강법)방법을 사용합니다.

Gradient Descent(경사 하강법)의 원리는 다음과 같습니다.

  • 앞을 보기 힘들 정도로, 안개가 낀 협곡에 있다고 가정합시다.
  • 앞이 보이지 않아, 협곡을 내려가는 길을 찾을 수가 없지만, 바로 앞에서의 협곡의 경사는 확인할 수는 있습니다.
  • 경사가 내려가면 해당 방향으로 가고, 경사가 올라가면 내려가는 경사를 찾아 방향을 정합니다.
  • 이렇게 꾸준히 조심히 내려가다 보면, 언젠가는 바닥에 도착할 것입니다.

이를 그대로 알고리즘 원리에 비유를 하면,

  • Gradient Descent는 임의의 위치에서 시작합니다.(위치 : 가중치, 회귀계수)

    • 임의로 주어지는 가중치는 범위로 주어집니다.

      • 전체 조합이 아닌, 해당 위치를 기준으로한 가중치를 범위로 잡습니다.
  • SSE(Sum of Squared Error)을 계산합니다.

    • SSE를 통해 해당 지역(가중치, 회귀계수의 범위 내)에서의 Error Surface를 찾을 수 있습니다.
  • Error surface에서 경사를 계산합니다.

    • 경사가 내려가는 곳을 방향으로 설정한 후, 새로운 Error Surface로 이동을 합니다.
  • 계산이 무수히 반복되다보면, global minumum을 찾을 수 있습니다.

    • 가중치들은 계산이 진행됨에 따라서 한 곳으로 수렴하게 되는 데, 해당 부분이 Globar minimum이 될 것입니다.

일반적으로 Gradient Descent는 다음의 그래프를 통해 이해한다면 더 수월해질 것입니다.

여기서, learning rate 라는 것이 등장하는데, 이는 쉽게 말하면 계곡에서 한걸음씩 내려갈 때 보폭을 어느 정도의 크리고 설정하는 정도라고 생각하시면 됩니다. 보폭이 크면 씩씩하게 걸을 수는 있지만, 방향을 제대로 잡지 못해 알고리즘이 잘못 작동할 가능성이 있습니다. 반대로 보폭을 너무 짧게 하면 방향은 정확하게 정해서 내려갈 수 있겠지만, 속도가 너무 느리게 됩니다. 이는 저희의 퇴근시간에 매우 방해되는 문제입니다. 이 역시도 learning rate를 잘 설정하는 센스가 중요한 영역입니다.

아래의 그래프는 가중치가 변할 때마다 Error값이 어떻게 낮아지는지 나타냈습니다. 흔히, Cost function이라고 합니다.

왼쪽은 learning rate가 너무 크게 잡혔을 경우, 방향을 잘못 잡아 오히려 올라가는 현상이 발생할 수도 있습니다. 오른쪽은 반대로 learning rate가 너무 낮게 잡혀, 내려가는 과정이 매우 뎌지게 진행됩니다.

learning rate가 내려가는 보폭을 정한다면, 방향을 정하는 것은 Delta Function이 담당합니다. 계산 방식은 편미분을 통해 진행이 되는데, 수식이 조금 많이 복잡할 수도 있지만, 간단하게 원리를 설명해드리겠습니다. 여러분들도 아시다시피 미분을 활용한 최소 혹은 최대점을 찾기 위한 방법입니다. Delta Function은 error surface가 낮아지는 방향을 제시해주는 역할을 합니다.

  • Error Loss function 정의

먼저 Error function을 정의하는데, 여기서 는 SSE라고 생각하시면 됩니다. 가 벡터 형태로 표시되어 이질감이 드는 것일 뿐, 풀어쓰면 이기 때문입니다.

  • Error Surface에서 편미분을 통해 global minimum 찾기

Loss Function을 각 가중치(기울기, 회귀계수)로 편미분을 하여 0 이되는 지점을 찾습니다. 이는 회귀분석에서 최소제곱법과 비슷한 방식이라고 생각하면 됩니다. 만약 회귀식이 Multiple인 경우에는 식의 형태는 조금 변하겠지만 원리는 바뀌지 않습니다.

복잡한 편미분을 계산하면 Delta Function이 계산이 됩니다.

  • 방향 설정 방법

이렇게 learning rate 를 곱해준 값을 더해줌으로써 방향과 보폭을 결정할 수 있습니다.

Learning rate 및 초기 가중치 설정 방법

마지막으로 최적의 learning rate 및 초기 가중치 설정 방법을 찾아주는 그렇다할 이론은 없다는 것이 학계의 정설입니다. 초기 가중치 설정은 변수를 Normalizaton을 한 후 추정하는 것이 유리합니다. 여기서 Normalization이란,

변환을 의미합니다. 변환 후에 x 값은 [0,1]사이의 상대적인 비율값을 가지게 됩니다. Learning rate 및 초기 가중치 설정은 경험적인 판단에서 결정을 내리는 것이 좋습니다.

요약하면 열심히 반복문 코드를 구성해서 실험하라는 뜻입니다.

댓글

이 블로그의 인기 게시물

6.1.2 고수들이 자주 쓰는 R코드 소개 2편 [중복 데이터 제거 방법]

Ch2. 중복데이터 제거하기 및 데이터 프레임 정렬 Ch2. 중복데이터 제거하기 및 데이터 프레임 정렬 흔하지는 않지만, 중복으로 입력되는 데이터 셋을 마주치는 일이 생기기 마련입니다. 보통 중복데이터는 데이터 수집단계에서 많이 발생합니다. 하지만 이를 하나하나 엑셀로 처리하는 것은 한계가 있기때문에, R에서 처리하는 방법에 대해 다루어 보고자 합니다. 1차원 벡터, 리스트에서의 중복 제거 A = rep(1:10, each = 2) # 1 ~ 10까지 2번씩 반복 print(A) ## [1] 1 1 2 2 3 3 4 4 5 5 6 6 7 7 8 8 9 9 10 10 # 중복 제거 unique(A) ## [1] 1 2 3 4 5 6 7 8 9 10 데이터 프레임에서의 중복 제거 다음과 같은 데이터 프레임을 예시로 삼겠습니다. 변수 설명 OBS : 번호 NAME : 환자 이름 ID : 환자 고유번호 DATE : 검사 날짜 BTW : Body total water 먼저, 환자 이름이 있고, 그 환자의 고유 ID가 있습니다. 세상에 동명이인은 많기 때문에 항상 고유 ID를 기록해둡니다. # 데이터 불러오기 DUPLICATE = read.csv("C:/R/DUPLICATED.csv") DUPLICATE ## OBS NAME ID DATE BTW ## 1 1 A A10153 2018-11-30 1 ## 2 2 A A10153 2018-11-30 3 ## 3 3 B B15432 2018-11-30 4 ## 4 4 A A15853 2018-11-29 5 ## 5 5 C C54652 2018-11-28 5 ## 6 6 C C54652 2018-11-27 6 ## 7 7 D D14...

4.4.1 R 문자열(TEXT) 데이터 처리하기 1

Ch4. 문자열 데이터 다루기 1 데이터 다운로드 링크: https://www.kaggle.com/PromptCloudHQ/imdb-data # 데이터 불러오기 DATA=read.csv("C:\\R/IMDB-Movie-Data.csv") Ch4. 문자열 데이터 다루기 1 이번에는 문자열 데이터를 처리하는 방법에 대해 다루겠습니다. 문자열을 다룰 때 기본적으로 숙지하고 있어야 하는 명령어는 다음과 같습니다. 문자열 대체 : gsub() 문자열 분리 : strsplit() 문자열 합치기 : paste() 문자열 추출 : substr() 텍스트마이닝 함수: Corpus() & tm_map(), & tdm() # 문자열 추출 substr(DATA$Actors[1],1,5) # 첫번째 obs의 Actors변수에서 1 ~ 5번째에 해당하는 문자열 추출 ## [1] "Chris" # 문자열 붙이기 paste(DATA$Actors[1],"_",'A') # 첫번째 obs의 Actors변수에서 _ A 붙이기, 기본적으로 띄어쓰기르 구분 ## [1] "Chris Pratt, Vin Diesel, Bradley Cooper, Zoe Saldana _ A" paste(DATA$Actors[1],"_",'A',sep="") # 띄어쓰기 없이 붙이기 ## [1] "Chris Pratt, Vin Diesel, Bradley Cooper, Zoe Saldana_A" paste(DATA$Actors[1],"_","Example",sep="|") # |로 붙이기 ## [1] "Chris Pratt, Vin Diesel, Bradley Cooper, Zoe Saldana|...

3. Resampling 방법론(Leave one out , Cross Validation)

Ch4. Resampling 방법론 이전 챕터에서는 앙상블에 대해 다루었습니다. 앙상블을 요약하자면, Training Set을 Resampling할 때 마다, 가중치를 조정할 것인지 말 것인지를 다루는 내용이었습니다. 이번에는 구체적으로 Resampling 방법들에 대해 다루어 보고자 합니다. 1. Resampling의 목적과 접근 방식 모형의 변동성(Variability)을 계산하기 위해서 입니다. Training Set으로 모형을 만들고, Test Set으로 Error rate를 계산하며, 이를 반복합니다. 각 실행 별, Error Rate 값이 계산이 될 것이며, 해당 Error rate의 분포를 보고 모형의 성능을 평가할 수 있습니다. Model Selection : 모형의 성능을 Resampling 방법론을 통해 평가한다면, 모델링 과정에서 어떤 변수를 넣어야 하고, 혹은 모형의 유연성(Flexibility)을 어느정도로 조절하는 것이 적당한지 결정을 할 수 있기 때문에 매우 중요한 방법론 중 하나입니다. 모형의 유연성에 대해서는 다음 챕터에서 설명하도록 하겠습니다. 2. Leave-One-Out Cross Validation(LOOCV) LOOCV는 n개의 데이터에서 1개를 Test Set으로 정하고 나머지 n-1개의 데이터로 모델링을 하는 방법을 의미합니다. LOOCV 방법은 데이터 수 n이 크다면, n번의 모델링을 진행해야되기 때문에, 시간이 오래 걸립니다. 회귀, 로지스틱, 분류모형 등에 다양하게 적용할 수 있습니다. 3. K - Fold Cross - Validation 연산시간이 오래걸린 다는 것은 곧, 작업시간이 길어진다는 의미이며 이는 곧 야근을 해야된다는 소리와 다를게 없어집니다. 그래서 시간이 오래걸리는 LOOCV를 대채하기 위하여 K-Fold Cross - Validation이 존재합니다. 위 그림은 데이터 셋을 총 4개의 Set로 구성하였습니다. Cross -...

4. 통계적 추정(점추정,구간추정)

Ch1. 점추정 추정량은 우리가 알고 싶어하는 모수를 표본들을 이용하여 단 하나의 점으로 추측하는 통계량입니다. 그 과정을 점추정(Point estimation)이라고 하며, 그렇게 얻어진 통계량을 점주청량(Point estimator)라고 합니다. 점추정량은 다양한 방식으로 구할 수 있습니다. 모평균을 추정하기 위한 표본평균 계산 각 끝의 일정 부분씩은 무시하고 나머지 표본들의 평균 계산(절삭 평균, Trimmed Mean) 등의 방법들이 있습니다. 하지만 가장 많이 쓰는 척도는 표본평균입니다. 그 이유는 대표적으로 수리적인 확장성과 표본평균의 분포를 비교적 쉽게 알 수 있다는 점을 들 수 있습니다. 점추정은 단순히 모평균을 추정하는 것만이 아닌, 회귀식을 추정하였을 때의 회귀계수도 점추정이라고 할 수 있습니다. (회귀분석은 후에 다룰 예정입니다.) 다만, 이런 점추정에도 몇 가지의 장점과 단점이 있습니다. 점추정의 장점 점추정량은 지극히 직관적이다. 통계를 모르는 누군가가 한국의 30대 여성의 평균 수입을 묻는다면 점추정량으로 즉각적인 답을 줄 수 있을 것입니다. 점추정량은 매우 직관적이며 합리적입니다. 점추정량은 우리가 원하는 수치를 대체할 구체적인 값을 제시해준다. 우리가 통계적인 모델링 혹은 함수를 작성하기 위해 30대 여성 수입의 평균치가 필요하나 모평균을 알 수 없을 때 점추정량으로 간단히 대체할 수 있습니다. 사실상 이는 대부분에 통계이론을 전개하는데 가장 중요한 역할을 합니다. 간단한 예를 말씀드리자면 모분산을 추정하기 위해서는 평균이 필요하기 때문에 표본평균을 이용합니다. 여기서 분산은 각 개별 값들이 평균에서 얼만큼 멀리 떨어져있는지에 대한 척도입니다. 그런데 우리는 '진짜 평균'을 알 수 없으니 아래 식과 같이 표본들의 평균으로 대체하는 것입니다. 여기서 평균 값을 표본평균으로 대체하였기에 표본분산은 n이 아닌 n-1으로 나누어 주게 됩니다. 이해를 돕기 위해 자유도에 대한 개념을 잠깐 다루도록 하...

3.2.3 R 시각화[ggplot2] 2편 (히스토그램, 밀도글래프, 박스플롯, 산점도)

R 데이터 시각화 2편 R 데이터 시각화 2편 데이터 다운로드 링크: https://www.kaggle.com/liujiaqi/hr-comma-sepcsv # 시각화 이전에 처리 되어 있어야 하는 시각화 DATA = read.csv('C:/R/HR_comma_sep.csv') DATA$left = as.factor(DATA$left) DATA$Work_accident = as.factor(DATA$Work_accident) DATA$promotion_last_5years = as.factor(DATA$promotion_last_5years) 히스토그램(Histogram) [연속형 변수 하나를 집계 내는 그래프, 1차원] 히스토그램은 연속형변수를 일정 범위로 구간을 만들어, x축으로 설정하고 y축은 집계된 값(Counting)을 나타내는 그래프입니다. library(ggplot2) # 기본 ggplot(DATA,aes(x=satisfaction_level))+ geom_histogram() ## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`. # 구간 수정 및 색 입히기 ggplot(DATA,aes(x=satisfaction_level))+ geom_histogram(binwidth = 0.01,col='red',fill='royalblue') # col은 테두리, fill은 채우기 밀도그래프(Density Plot)[연속형 변수 하나를 집계 내는 그래프, 1차원] 밀도그래프는 연속형변수를 일정 범위로 구간을 만들어, x축으로 설정하고 y축은 집계된 값(percentage)을 나타내는 그래프입니다. # 기본 ggplot(DATA,aes(x=satisfaction_level))+ geom_density() # 색 입히기 ggplot(DATA,a...