訂閱
糾錯
加入自媒體

數據分析大佬用Python代碼教會你Mean Shift聚類

MeanShift算法可以稱之為均值漂移聚類,是基于聚類中心的聚類算法,但和k-means聚類不同的是,不需要提前設定類別的個數k。在MeanShift算法中聚類中心是通過一定范圍內樣本密度來確定的,通過不斷更新聚類中心,直到最終的聚類中心達到終止條件。整個過程可以看下圖,我覺得還是比較形象的。

MeanShift向量

MeanShift向量是指對于樣本X1,在以樣本點X1為中心,半徑為h的高維球區域內的所有樣本點X的加權平均值,如下所示,同時也是樣本點X1更新后的坐標。

而終止條件則是指 | Mh(X) - X |<ε,滿足條件則樣本點X1停止更新,否則將以Mh(X)為新的樣本中心重復上述步驟。

核函數

核函數在機器學習(SVM,LR)中出現的頻率是非常高的,你可以把它看做是一種映射,是計算映射到高維空間之后的內積的一種簡便方法。在這個算法中將使用高斯核,其函數形式如下。

h表示帶寬,當帶寬h一定時,兩個樣本點距離越近,其核函數值越大;當兩個樣本點距離一定時,h越大,核函數值越小。核函數代碼如下,gaosi_value為以樣本點X1為中心,半徑為h的高維球范圍內所有樣本點與X1的高斯核函數值,是一個(m,1)的矩陣。

def gaussian_kernel(self,distant):    m=shape(distant)[1]#樣本數    gaosi=mat(zeros((m,1)))    for i in range(m):        gaosi[i][0]=(distant.tolist()[0][i]*distant.tolist()[0][i]*(-0.5)/(self.bandwidth*self.bandwidth))        gaosi[i][0]=exp(gaosi[i][0])    q=1/(sqrt(2*pi)*self.bandwidth)    gaosi_value=q*gaosi    return gaosi_value

MeanShift向量與核函數

在01中有提到MeanShift向量是指對于樣本X1,在以樣本點X1為中心,半徑為h的高維球區域內的所有樣本點X的加權平均值。但事實上是不同點對于樣本X1的貢獻程度是不一樣的,因此將權值(1/k)更改為每個樣本與樣本點X1的核函數值。改進后的MeanShift向量如下所示。

其中

就是指高斯核函數,Sh表示在半徑h內的所有樣本點集合。

MeanShift算法原理

在MeanShift算法中實際上利用了概率密度,求得概率密度的局部最優解。

對于一個概率密度函數f(x),已知一個概率密度函數f(X),其核密度估計為

其中K(X)是單位核,概率密度函數f(X)的梯度估計為

其中G(X)=-K'(X)。第一個中括號是以G(X)為核函數對概率密度的估計,第二個中括號是MeanShift 向量。因此MeanShift向量是與概率密度函數的梯度成正比的,總是指向概率密度增加的方向。

而對于MeanShift向量,可以將其變形為下列形式,其中mh(x)為樣本點X更新后的位置。

MeanShift算法流程

在未被標記的數據點中隨機選擇一個點作為起始中心點X;

找出以X為中心半徑為radius的區域中出現的所有數據點,認為這些點同屬于一個聚類C。同時在該聚類中記錄數據點出現的次數加1。

以X為中心點,計算從X開始到集合M中每個元素的向量,將這些向量相加,得到向量Mh(X)。

mh(x) =Mh(X) + X。即X沿著Mh(X)的方向移動,移動距離是||Mh(X)||。

重復步驟2、3、4,直到Mh(X)的很小(就是迭代到收斂),記住此時的X。注意,這個迭代過程中遇到的點都應該歸類到簇C。

如果收斂時當前簇C的center與其它已經存在的簇C2中心的距離小于閾值,那么把C2和C合并,數據點出現次數也對應合并。否則,把C作為新的聚類。

重復1、2、3、4、5直到所有的點都被標記為已訪問。

分類:根據每個類,對每個點的訪問頻率,取訪問頻率最大的那個類,作為當前點集的所屬類。

TIPS:每一個樣本點都需要計算其漂移均值,并根據計算出的漂移均值進行移動,直至滿足終止條件,最終得到的均值漂移點為該點的聚類中心點。

MeanShift算法代碼

from numpy import *from matplotlib import pyplot as plt
class mean_shift():    def __init__(self):        #帶寬        self.bandwidth=2        #漂移點收斂條件        self.mindistance=0.001        #簇心距離,小于該值則兩簇心合并        self.cudistance=2.5
   def gaussian_kernel(self,distant):        m=shape(distant)[1]#樣本數        gaosi=mat(zeros((m,1)))        for i in range(m):            gaosi[i][0]=(distant.tolist()[0][i]*distant.tolist()[0][i]*(-0.5)/(self.bandwidth*self.bandwidth))            gaosi[i][0]=exp(gaosi[i][0])        q=1/(sqrt(2*pi)*self.bandwidth)        gaosi_value=q*gaosi        return gaosi_value
   def load_data(self):        X =array([    [-4, -3.5], [-3.5, -5], [-2.7, -4.5],    [-2, -4.5], [-2.9, -2.9], [-0.4, -4.5],    [-1.4, -2.5], [-1.6, -2], [-1.5, -1.3],    [-0.5, -2.1], [-0.6, -1], [0, -1.6],    [-2.8, -1], [-2.4, -0.6], [-3.5, 0],    [-0.2, 4], [0.9, 1.8], [1, 2.2],    [1.1, 2.8], [1.1, 3.4], [1, 4.5],    [1.8, 0.3], [2.2, 1.3], [2.9, 0],    [2.7, 1.2], [3, 3], [3.4, 2.8],    [3, 5], [5.4, 1.2], [6.3, 2],[0,0],[0.2,0.2],[0.1, 0.1],[-4, -3.5]])        x,y=[],[]        for i in range(shape(X)[0]):            x.append(X[i][0])            y.append(X[i][1])        plt.scatter(x,y,c='r')        # plt.plot(x, y)        plt.show()        classlable=mat(zeros((shape(X)[0],1)))        return  X,classlable
   def distance(self,a,b):        v=a-b        return sqrt(v*mat(v).T).tolist()[0][0]    def shift_point(self,point,data,clusterfrequency):        sum=0        n=shape(data)[0]        ou=mat(zeros((n,1)))        t=mat(zeros((n,1)))        newdata=[]        for i in range(n):            # print(self.distance(point,data[i]))            d=self.distance(point,data[i])            if d<self.bandwidth:                ou[i][0] =d                t[i][0]=1                newdata.append(data[i])                clusterfrequency[i]=clusterfrequency[i]+1        gaosi=self.gaussian_kernel(ou[t==1])        meanshift=gaosi.T*mat(newdata)        return meanshift/gaosi.sum(),clusterfrequency
   def group2(self,dataset,clusters,m):        data=[]        fre=[]        for i in clusters:            i['data']=[]            fre.append(i['frequnecy'])        for j in range(m):            n=where(array(fre)[:,j]==max(array(fre)[:,j]))[0][0]            data.append(n)            clusters[n]['data'].append(dataset[j])        print("一共有%d個簇心" % len(set(data)))        # print(clusters)        # print(data)        return clusters
   def plot(self,dataset,clust):        colors = 10 * ['r', 'g', 'b', 'k', 'y','orange','purple']        plt.figure(figsize=(5, 5))        plt.xlim((-8, 8))        plt.ylim((-8, 8))        plt.scatter(dataset[:, 0],dataset[:, 1], s=20)        theta = linspace(0, 2 * pi, 800)        for i in range(len(clust)):            cluster = clust[i]            data = array(cluster['data'])            if len(data):                plt.scatter(data[:, 0], data[:, 1], color=colors[i], s=20)            centroid =cluster['centroid'].tolist()[0]            plt.scatter(centroid[0], centroid[1], color=colors[i], marker='x', s=30)            x, y = cos(theta) * self.bandwidth + centroid[0], sin(theta) *  self.bandwidth  + centroid[1]            plt.plot(x, y, linewidth=1, color=colors[i])        plt.show()
   def mean_shift_train(self):        dataset, classlable = self.load_data()        m = shape(dataset)[0]        clusters = []        for i in range(m):            max_distance = inf            cluster_centroid = dataset[i]            # print(cluster_centroid)            cluster_frequency =zeros((m,1))            while max_distance>self.mindistance:                w,cluster_frequency = self.shift_point(cluster_centroid,dataset,cluster_frequency)                dis = self.distance(cluster_centroid, w)                if dis < max_distance:                    max_distance = dis                    # print(max_distance)                cluster_centroid = w            has_same_cluster = False            for cluster in clusters:                if self.distance(cluster['centroid'],cluster_centroid)<self.cudistance:                    cluster['frequnecy']=cluster['frequnecy']+cluster_frequency                    has_same_cluster=True                    break            if not has_same_cluster:                clusters.append({'frequnecy':cluster_frequency,'centroid':cluster_centroid})        clusters=self.group2(dataset,clusters,m)        print(clusters)        self.plot(dataset,clusters)
if __name__=="__main__":    shift=mean_shift()    shift.mean_shift_train()

得到的結果圖如下。

之后還會詳細解說K-means聚類以及DBSCAN聚類,敬請關注。

聲明: 本文由入駐維科號的作者撰寫,觀點僅代表作者本人,不代表OFweek立場。如有侵權或其他問題,請聯系舉報。

發表評論

0條評論,0人參與

請輸入評論內容...

請輸入評論/評論長度6~500個字

您提交的評論過于頻繁,請輸入驗證碼繼續

暫無評論

暫無評論

文章糾錯
x
*文字標題:
*糾錯內容:
聯系郵箱:
*驗 證 碼:

粵公網安備 44030502002758號

pc28am参考结果