必威体育Betway必威体育官网
当前位置:首页 > IT技术

python实现K_mean算法

时间:2019-07-05 02:45:29来源:IT技术作者:seo实验室小编阅读:84次「手机版」
 

1653

# encoding: utf-8
'''
#!/usr/bin/env Python
@author: yudian
@contact: [email protected]
@file: k_means.py
@time: 2018/12/25 0:16
@desc:    JUST WaiTING!!!
    程序功能:动态演示K_mean聚类过程。

    可调参数: k(全局变量):聚类中心点数目
               number(get_datas()):生成的点的数量, 最好为7的倍数
               bigCircleTimes: 重复整个聚类过程的次数;
               smallCircleTimes: 单次聚类最多迭代次数。        上两个参数都是函数k_mean_cluster()的局部变量。

    各函数功能:    get_datas(): 得到待分类数据点。所有的带领都均匀分布在[0,0],[1,1],[0,5],[3,5],[2,3],[8,10],[12,12]周围
                                 输入: number(可选):代表输入点的数目
                                 输出: dataArray: number*2的矩阵。类型为numpy.ndarray

                    myShow():    显示数据。根据传入的参数执行不同的显示功能。
                                只有data: 显示初始未分类数据。
                                data, centerArray, indexMatrix: 显示最终分类结果
                                data, centerArray, indexMatrix, initArray: 动态重现最优聚类过程。
                                输入: data: 待显示的数据
                                       centerArray: k个聚类中心点的坐标向量。
                                       indexMatrix: data属于那个聚类中心的指示矩阵
                                       initArray: 最优聚类过程中的初始化聚类中心矩阵。
'''
import numpy as np
import time
import matplotlib.pyplot as plt

def get_datas(number = 490):
    center = np.array([[0,0],[1,1],[0,5],[3,5],[2,3],[8,10],[12,12]])
    center.reshape(-1,2)
    centerNumber = center.shape[0]
    dataArray = np.empty((number,2))
    for i in (range(int(number/centerNumber))):
        randomData = np.random.random_sample((centerNumber,2))
        dataArray[i*centerNumber:(i+1)*centerNumber] = center + randomData
    # myShow(dataArray)
    return dataArray

def myShow(data, centerArray=None, indexMatrix = None, initArray = None):
    if type(centerArray) == type(None):
        print(type(centerArray))
        x = data[:, 0]
        y = data[:, 1]
        plt.scatter(x, y, marker='o', s = 5, c = 'black')
        plt.show()
    else:
        if type(initArray) == type(None):
            print(centerArray)
            colorList = ['red', 'blue', 'green', 'black', 'aliceblue', 'coral', 'firebrick', 'ivory', 'linen', 'mintcream'] 
    # choose colors website: HTTPs://www.cnblogs.com/darkknightzh/p/6117528.HTML
            plt.ion()
            dataX = data[:, 0]
            dataY = data[:, 1]
            plt.scatter(dataX, dataY, marker='o', s = 5, c = 'black')
            centerArray = centerArray.reshape(centerArray.shape[0], 2)
            centerX = centerArray[:, 0]
            centerY = centerArray[:, 1]
            plt.scatter(centerX, centerY, marker = '*', s = 100, c = 'red')
            for i in range(centerArray.shape[0]):
                centerMatrix = data[np.argwhere(indexMatrix == i)]
                print(centerMatrix.shape)
                centerMatrixX = centerMatrix[:, 0, 0]
                centerMatrixY = centerMatrix[:, 0, 1]
                plt.scatter(centerMatrixX, centerMatrixY, marker = 'o', s = 5, c = colorList[i])
            plt.pause(1)
            # plt.cla()
        else:
            plt.ion()
            print(initArray)
            centerArray = initArray
            for localTimes in range(30):
                oldCenterArray = centerArray.copy()
                distMatrix = get_distance_matrix(data, centerArray)
                indexMatrix = cluster(distMatrix)
                myShow(data, centerArray, indexMatrix)
                plt.cla()
                centerArray = find_next_center(data, indexMatrix, k)
                myShow(data, centerArray, indexMatrix)
                plt.pause(1)
                if(centerArray == oldCenterArray).all():
                    plt.pause(3)
                    break
                plt.cla()


def k_mean_cluster(data,k):
    bigCircleTimes = 50
    smallCircleTimes = 30
    totalDist = np.inf
    for times in range(bigCircleTimes):
        number = data.shape[0]
        centerArray = data[np.random.randint(0, number, size = k)]
        initArray = centerArray.copy()
        for localTimes in range(smallCircleTimes):
            oldCenterArray = centerArray.copy()
            distMatrix = get_distance_matrix(data, centerArray)
            indexMatrix = cluster(distMatrix)
            centerArray = find_next_center(data, indexMatrix, k)
            if((oldCenterArray == centerArray).all()):
                print('total have done.')
                break
        if(totalDist > np.sum(distMatrix)):
            totalDist = np.sum(distMatrix)
            holdOnCenterArray = centerArray
            holdOnIndexMatrix = indexMatrix
            holdOnInitArray = initArray
    print(totalDist)
    myShow(data, holdOnCenterArray, holdOnIndexMatrix, holdOnInitArray)


def find_next_center(data, indexMatrix, k):
    assert(data.shape[0] == indexMatrix.shape[0])
    newCenterArray = np.empty(shape=(k, 2))
    for i in range(k):
        tempArray = data[np.argwhere(indexMatrix == i)[:, 0]]
        if( not tempArray.size):
            print("have center no data.Need change center.")
            newCenterArray[i] = data[np.random.randint(0, data.shape[0])]
        else:
            newCenterArray[i] = np.sum(tempArray, axis = 0)/tempArray.shape[0]
    return newCenterArray

def get_distance_matrix(data, initCenter):
    centerX = initCenter[:, 0].reshape(1, -1)
    centerY = initCenter[:, 1].reshape(1, -1)
    dataX = data[:, 0].reshape(-1, 1)
    dataY = data[:, 1].reshape(-1, 1)
    distMatrix = np.power((centerX - dataX), 2) + np.power((centerY - dataY), 2)
    return distMatrix

def cluster(distMatrix):
    number = distMatrix.shape[0]
    minMatrix = np.min(distMatrix, axis = 1)
    indexMatrix = np.empty(shape = (number,))
    for i in range(number):
        try:
            indexMatrix[i] = np.argwhere(distMatrix[i] == minMatrix[i])[0][0]
        except IndexERROR as e:
            print(minMatrix)
            raise Exception
    return indexMatrix


global k
k = 3

if __name__ == "__main__":
    data = get_datas()
    k_mean_cluster(data, k)

相关阅读

Svm算法原理及实现

Svm(support Vector Mac)又称为支持向量机,是一种二分类的模型。当然如果进行修改之后也是可以用于多类别问题的分类。支持向量机可

Floyd algorithm!!!!!(万恶的弗洛伊德算法)

曾经有位滑稽的博主说过:搜索就是优雅的暴力。今天他又要说,DP就是优雅地搜索。 不是每一个弗洛伊德都写算法,也不是写算法的都叫弗

面试中常见的数据结构与算法

第二章排序 2.1 O(n2) 算法 给定一数组,其大小为8个元素,数组内的数据无序。 6 3 5 7 0 4 1 2 冒泡排序:两两比较,将两者较少的升

机器学习算法(8)之多元线性回归分析理论详解

前言:当影响因变量的因素是多个时候,这种一个变量同时与多个变量的回归问题就是多元回归,分为:多元线性回归和多元非线性回归。线性回

图的遍历之DSF深度优先算法6.2.1(网络整理)

图的遍历之深度优先算法伪代码描述(和树的前序遍历相似,实际上树可以看成特殊的图:N个顶点有N-1条边,不曾在回路!即树是图连通中最少边

分享到:

栏目导航

推荐阅读

热门阅读