본문 바로가기
데이터 분석/시각화

[Matplotlib] Strip Plot(Jitter Plot)을 그려보자

by 부자 꽁냥이 2022. 9. 16.

안녕하세요~ 꽁냥이입니다. 범주형 변수와 수치형 변수를 같이 시각화할 때 Strip Plot(또는 Jitter Plot)을 이용합니다. 안타깝게도 Matplotlib에서는 Strip Plot(Jitter Plot)을 제공하고 있지 않은데요. 그래서 꽁냥이가 Matplotlib으로 하는 방법을 개발해 보았어요.

 

이번 포스팅에서는 Matplotlib을 이용한 Strip Plot(Jitter Plot)을 그리는 방법에 대해서 공유합니다.

 

- 목차 -

1. 랜덤 Strip Plot(Jitter Plot)

2. 정렬된 Strip Plot(Jitter Plot)

 

Seaborn 라이브러리에는 Strip Plot을 그리는 기능을 제공하고 있는데요. 궁금하신 분들은 아래 포스팅을 참고해주세요.

 

[Seaborn] 8. Strip Plot(Jitter Plot) 그리기 (feat. stripplot)

안녕하세요~ 꽁냥이에요. 이번 포스팅은 Seaborn에서 stripplot을 이용하여 Strip Plot(Jitter Plot)을 그리는 방법을 소개하려고 합니다. - 목차 - 1. Seaborn stripplot 기본 2. Seaborn stripplot 다양한 기능..

zephyrus1111.tistory.com


   1. 랜덤 Strip Plot(Jitter Plot)

먼저 랜덤 Strip Plot은 y좌표는 그대로 가져오고 각 x좌표 중심으로 데이터의 x좌표를 랜덤하게 할당하는 방법입니다. 아래 그림을 통해 구체적으로 설명하자면 degree라는 것을 설정하여 x좌표 양옆으로 영역(보라색 점선 사이)을 만들어줍니다. 이때 다른 범주 영역에 침범하지 않도록 degree를 정해줘야 하고요. 영역이 만들어졌다면 해당 영역 사이에서 x좌표를 할당하는 것입니다.

이제 아이디어를 알았으니 이를 코드로 만들어줘야합니다. 아래 코드가 랜덤 Jitter Plot을 그리는 함수입니다. 설명은 주석으로 대체합니다.

 

def random_jitter_plot(x, y, ax, degree=0.3, random_state=100, kwargs=None):
    assert degree < 0.5 ## 0.5를 넘어가면 다른 범주 영역 침범하므로 이를 방지한다.
    np.random.seed(random_state) ## 재생성을 위한 시드 넘버

    def jitter_number(pos, degree): ## pos 중심으로 좌우 degree 내에서 랜덤하게 x좌표 할당
        return (pos-degree+2*degree*np.random.rand(1))[0]

    xticks = range(len(set(x))) ## x눈금 좌표 
    if str(x.dtype) == 'object':
        category = list(set(x)) ## 유니크 범주
        category_to_num = dict(zip(category, xticks)) ## 범주를 숫자로
        num_to_category = dict(zip(xticks, category)) ## 숫자를 범주로 바꾸는 딕셔너리
        x_num = list(map(lambda x:category_to_num[x], x)) ## 범주를 숫자로 바꿈
    else:
        x_num = x
    x_jittered = [jitter_number(x, degree) for x in x_num] ## 새롭게 x좌표 할당

    ax.scatter(x_jittered, y, alpha=0.5, c=x_num, **kwargs)
    
    ## x축 눈금 설정
    ax.set_xticks(xticks) 
    if str(x.dtype) == 'object':
        ax.set_xticklabels([num_to_category[x] for x in xticks])
    
    ## x축 표시 범위 설정
    ax.set_xlim(-0.5, np.max(xticks)+0.5)

 

구현을 했으니 실제로 그려봐야겠죠? 먼저 Jitter Plot(Strip Plot)을 그릴 데이터를 준비합니다. 꽁냥이는 붓꽃 데이터를 사용했어요.

 

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from sklearn.datasets import load_iris

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
num_to_species = dict(zip(range(len(iris.target_names)), iris.target_names ))
df['species'] = list(map(lambda x:num_to_species[x], iris.target))

 

이제 itter Plot(Strip Plot)을 그려봅시다.

 

y = df['sepal width (cm)']
x = df['species']

fig = plt.figure(figsize=(8,8))
fig.set_facecolor('white')
ax = fig.add_subplot()
random_jitter_plot(x, y, ax, degree=0.3, random_state=100, kwargs={'cmap':'cool'})
plt.show()

 

 

랜덤 Jitter Plot(Strip Plot)이 예쁘게 그려진 것을 확인할 수 있습니다.


   2. 정렬된 Strip Plot(Jitter Plot)

정렬된 Strip Plot(Jitter Plot)은 y좌표를 히스토그램 계급으로 나누어 각 계급의 중앙점을 새로운 y좌표로 할당하고 각 계급 내에서 x축에 평행하게 그린 것을 말합니다. 이때 x 좌표는 각 계급에 포함된 데이터의 개수에 따라 균등하게 할당됩니다.

 

아이디어를 알았으면 코드로 구현하는 것이 인지상정입니다. 아래 코드는 정렬된 Strip Plot(Jitter Plot)을 그리는 코드입니다. 이 역시 설명은 주석으로 대체합니다.

 

def hist_jitter_plot(x, y, ax, degree=0.3, bins=20, kwargs=None):
    assert degree < 0.5 ## 0.5를 넘어가면 다른 범주 영역 침범하므로 이를 방지한다.
    
    def get_new_xy(y, pos, degree, bins): ## 계급별 새로운 y값과 x좌표 재 생성
        counts, edges = np.histogram(y, bins=bins) ## 계급별 데이터 개수, 계급 구간
        centres = (edges[:-1] + edges[1:])/2 ## 각 계급 구간의 중앙
        new_y = centres.repeat(counts) ## 기존 y를 해당 계급 중앙 값으로 변경
        base_width = 2*degree / counts.max() ## x좌표 간격
        offsets = np.hstack([(np.arange(c) - 0.5 * (c - 1)) for c in counts]) ## 각 계급별 위치 offset
        new_x = pos + (offsets*base_width) ## x좌표 위치로 변환
        return new_x, new_y

    xticks = range(len(set(x)))
    if str(x.dtype) == 'object':
        category = list(set(x)) ## 유니크 범주
        category_to_num = dict(zip(category, xticks)) ## 범주를 숫자로
        num_to_category = dict(zip(xticks, category)) ## 숫자를 범주로 바꾸는 딕셔너리
        x_num = list(map(lambda x:category_to_num[x], x)) ## 범주를 
    else:
        x_num = x

    ## 새로운 x, y좌표 생성
    xy_mat = np.column_stack([y, x_num])
    x_pos = []
    y_hist = []
    x_num2 = [] ## 그룹별 색상 적용을 위한 벡터
    for xt in xticks:
        temp_idx = np.where(xy_mat[:,1]==xt)[0]
        temp_x_pos, temp_y_hist = get_new_xy(xy_mat[temp_idx, 0], xt, degree, bins)
        x_pos.append(temp_x_pos)
        y_hist.append(temp_y_hist)
        x_num2.append(xy_mat[temp_idx, 1])

    x_pos = np.hstack(x_pos)
    y_hist = np.hstack(y_hist)
    x_num2 = np.hstack(x_num2)
    
    ax.scatter(x_pos, y_hist, alpha=0.5, c=x_num2, **kwargs)
    
    ## x축 눈금 설정
    ax.set_xticks(xticks)
    if str(x.dtype) == 'object':
        ax.set_xticklabels([num_to_category[x] for x in xticks])
        
    ## x축 표시 범위 설정
    ax.set_xlim(-0.5, np.max(xticks)+0.5)

 

이제 앞에서 사용한 붓꽃 데이터를 이용하여 정렬된 Strip Plot(Jitter Plot)을 그려봅시다.

 

y = df['sepal width (cm)']
x = df['species']
fig = plt.figure(figsize=(8,8))
fig.set_facecolor('white')
ax = fig.add_subplot()

hist_jitter_plot(x, y, ax, degree=0.3, bins=20, kwargs={'cmap':'cool'})
plt.show()

 

멋진 Strip Plot(Jitter Plot)이 그려졌습니다.


이번 포스팅에서는 Matplotlib을 이용한 Jitter Plot(또는 Strip Plot) 그리는 방법을 알아보았습니다. Jitter Plot이 그렇게 자주 사용되는지는 잘 모르겠지만 그래도 이번 포스팅을 해보면서 이러한 그래프가 생긴 이유, 즉 겹쳐진 데이터를 옆으로 흩어지게 해서 그 분포를 알 수 있게 하는 것이 Jitter Plot이라는 것이라는 것을 알았어요. 

 

지금까지 꽁냥이 글 읽어주셔서 감사하고 다음에도 좋은 주제로 찾아뵐 것을 약속드리며 이상 포스팅 마치겠습니다. 안녕히 계세요.

- 참고자료 -

https://stackoverflow.com/questions/8671808/matplotlib-avoiding-overlapping-datapoints-in-a-scatter-dot-beeswarm-plot


댓글


맨 위로