1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
|
import matplotlib.pyplot as plt
import numpy as np
def createdata():
    """Build the toy training set used by the demo.

    Returns:
        (samples, labels): a 4x2 integer array of 2-D points and a length-4
        integer array of labels (-1 for the first two points, +1 for the
        last two).
    """
    points = [[3, -3], [4, -3], [1, 1], [1, 2]]
    classes = [-1, -1, 1, 1]
    return np.array(points), np.array(classes)
class Perceptron:
    """Perceptron in its dual form, trained on a precomputed Gram matrix.

    Attributes:
        x: training samples, one row per sample (2-D array).
        y: labels; the update rule assumes +1 / -1 (as in createdata).
        w: dual coefficients alpha, shape (1, numsamples). The primal weight
            vector is recovered by cal_w().
        b: bias term.
        a: learning rate, i.e. the step added to alpha_i on each mistake.
        gMatrix: Gram matrix of pairwise inner products of the rows of x.
    """

    def __init__(self, x, y, a=1):
        self.x = x
        self.y = y
        # One dual coefficient per sample, kept as a 1 x n row vector.
        self.w = np.zeros((1, x.shape[0]))
        self.b = 0
        # Honour the caller's learning rate (it used to be hard-coded to 1).
        self.a = a
        self.numsamples = self.x.shape[0]
        self.numfeatures = self.x.shape[1]
        self.gMatrix = self.cal_gram(self.x)

    def cal_gram(self, x):
        """Return the Gram matrix ``x @ x.T`` as a float array.

        Args:
            x: 2-D array with one sample per row.

        Returns:
            (n, n) float array whose entry [i, j] is <x_i, x_j>.
        """
        x = np.asarray(x)
        # Vectorised replacement for the former O(n^2) Python double loop.
        return np.dot(x, x.T).astype(float)

    def sign(self, w, b, key):
        """Return the decision score for training sample ``key``.

        score = sum_j alpha_j * y_j * <x_j, x_key> + b

        The raw float score is returned rather than ``int(score)``. Truncating
        toward zero turned any margin in (-1, 1) into 0, which the caller
        then treated as a misclassification.

        Args:
            w: dual coefficients (any shape that ravels to length n).
            b: bias.
            key: index of the training sample to score.
        """
        alpha = np.ravel(w)
        return float(np.dot(alpha * self.y, self.gMatrix[:, key]) + b)

    def update(self, i):
        """Apply the dual perceptron update for misclassified sample ``i``."""
        # self.w has shape (1, n): alpha_i lives at [0, i]. Indexing row i
        # raised IndexError for i >= 1 and bumped every alpha for i == 0.
        self.w[0, i] += self.a
        self.b = self.b + self.y[i] * self.a

    def cal_w(self):
        """Recover the primal weights sum_i alpha_i * y_i * x_i, shape (1, numfeatures)."""
        return np.dot(self.w * self.y, self.x)

    def train(self, max_iter=None):
        """Run passes over the data until no sample is misclassified.

        Args:
            max_iter: optional cap on the number of passes. ``None`` (the
                default) keeps the original behaviour of looping until
                convergence, which never ends on data that is not linearly
                separable.

        Returns:
            (weights, b): primal weights of shape (1, numfeatures) and bias.

        Raises:
            RuntimeError: if ``max_iter`` passes finish without convergence.
        """
        epoch = 0
        while max_iter is None or epoch < max_iter:
            epoch += 1
            count = 0
            for i in range(self.numsamples):
                tmpY = self.sign(self.w, self.b, i)
                if tmpY * self.y[i] <= 0:
                    # Log: misclassified point x_i, plus current w and b.
                    print('误分类点为:', self.x[i, :], '此时的w和b为:', self.cal_w(), ',', self.b)
                    count += 1
                    self.update(i)
            if count == 0:
                # Log: final trained w and b.
                print('最终训练得到的w和b为:', self.cal_w(), ',', self.b)
                weights = self.cal_w()
                return weights, self.b
        raise RuntimeError(f'perceptron did not converge within {max_iter} passes')
class Picture:
    """Plot the learned separating line together with the sample points.

    Constructing an instance draws into matplotlib figure 1 and saves it to
    ``2d.png``. Call Show() to display the figure interactively.
    """

    def __init__(self, data, w, b, labels=None):
        """Draw the boundary and the samples, then save the figure.

        Args:
            data: samples, one (x0, x1) point per row.
            w: primal weights of shape (1, 2), e.g. from Perceptron.train().
            b: bias.
            labels: optional label per row of ``data``. When given, every
                point is drawn, with label > 0 as 'x' and the rest as
                circles. When omitted, the original fixed layout is used:
                points 0-1 as circles and points 2-3 as 'x' (so ``data``
                must then have at least four rows).
        """
        self.b = b
        self.w = w
        plt.figure(1)
        plt.title('Perceptron Learning Algorithm', size=14)
        plt.xlabel('x0-axis', size=14)
        plt.ylabel('x1-axis', size=14)
        w0, w1 = np.ravel(self.w)[:2]
        if w1 == 0:
            # Vertical boundary w0*x0 + b = 0; expression() would divide by zero.
            plt.axvline(-self.b / w0, color='r', label='decision boundary')
        else:
            xData = np.linspace(0, 5, 100)
            yData = self.expression(xData)
            # This line is the learned decision boundary, not sample data.
            plt.plot(xData, yData, color='r', label='decision boundary')
        if labels is None:
            # Legacy layout tailored to the four-point demo dataset.
            plt.scatter(data[0][0], data[0][1], s=50)
            plt.scatter(data[1][0], data[1][1], s=50)
            plt.scatter(data[2][0], data[2][1], s=50, marker='x')
            plt.scatter(data[3][0], data[3][1], s=50, marker='x')
        else:
            points = np.asarray(data)
            positive = np.asarray(labels) > 0
            plt.scatter(points[~positive, 0], points[~positive, 1], s=50)
            plt.scatter(points[positive, 0], points[positive, 1], s=50, marker='x')
        plt.savefig('2d.png', dpi=75)

    def expression(self, x):
        """Return x1 on the line w0*x0 + w1*x1 + b = 0 for the given x0 values."""
        y = (-self.b - self.w[:, 0] * x) / self.w[:, 1]
        return y

    def Show(self):
        """Display the current matplotlib figure."""
        plt.show()


if __name__ == '__main__':
    samples, labels = createdata()
    myperceptron = Perceptron(x=samples, y=labels)
    weights, bias = myperceptron.train()
    # Use a distinct name: rebinding `Picture` would shadow the class.
    picture = Picture(samples, weights, bias, labels)
    picture.Show()
|