Coding Memorandum

プログラミングに関する備忘録

スポンサーサイト

上記の広告は1ヶ月以上更新のないブログに表示されています。
新しい記事を書く事で広告が消せます。

SVMコード

SVM(SMO法)のコードを載せておきます。理論についてはこちらをどうぞ。


2011.9.18追記:
ここでも書いたのですが,引数 target は -1 or 1 と設定頂く方が正しいはずです。(理論的にも正しいはず)
初出時は簡単なテストデータで動作確認を行っていたためか 0 or 1 でも動作確認が取れていたよう記憶しています。本格的に学習させる場合には 0 or 1 では学習が収束しませんでした。
大まかな流れ

SMOではラグランジュ未定乗数が教師データの数量分作成されます。これらを下記の処理フローで更新していきます。

ラグランジュ乗数の更新評価は1つずつ順番に行い,全ての評価を終えたところで実際に更新が行われたかをチェックします。1件でも更新したものがあれば,更新の作業を続けていきます。

下記のコードでは,個々のラグランジュ未定乗数の更新チェックはexaminUpdate()で行っています。ここでは,KKT条件を満たすかどうかのチェックを行います。
KKT条件を満たさないとき,ラグランジュ乗数を更新させます。SMOでは2点をペアとして更新しますので,update()でペアとなるラグランジュ乗数を探し,stepSMO()で2点を更新しています。

実装上の考慮点

SMOを実装する上では,次の点の考慮が必要です。

▼ Loose KKT Conditions
Platt の本に書かれているのですが,KKT条件のチェックは緩く行います。下記のコードでは,eps,tolerance変数がその役割を担っています。

▼ 誤差値のキャッシュの更新
ラグランジュ乗数の更新時には,キャッシュしている誤差値の更新を行います。
Platt の本を含め,多くの資料で更新対象a[i]に対応する誤差は 0 にすると書かれているのですが,a[i] がclippingされるときは,必ずしも 0 にはなりません。下記コードでは,clippingされたときには実際に誤差値を求めることとしています。(Line.274)

a[i]とペアを組むa[j]に関する誤差は,基本的には誤差更新の式(Line.264)で更新可能なようですが,何かの条件で実際の誤差と合わなくなることがあるようです(条件は分かっていませんが)。下記コードでは,a[j]の誤差は更新式を使わず誤差値を実測することとしました。(Line.279)

コード

下記にSVMのコードを示します。本コードは,クラス分けは2値(0 or 1),学習/テストデータの各属性値は[0,1]に正規化した値で動作させました。その他の条件で動作するかは未評価です。

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <vector>
#include <algorithm>
using namespace std;

class SVM
{
public:
	// 学習
	int learning( vector< vector<double> >& train_set,	// 教師データ群
				  vector<double>& target				// 正解出力 (-1 or 1)
				);
	// 判別
	double discriminate( vector<double> test			// テストデータ
					   );
private:
	int examinUpdate(int i);		// a[i]の更新評価(KKT条件のチェック)
	int update(int i);				// a[i]との更新ペアa[j]を探す
	int stepSMO( int i, int j );	// a[i],a[j]を更新する
	double f(int i);
	double kernel( const vector<double> p, const vector<double> q );
private:
	vector<double> a;				// ラグランジュ乗数
	vector<double> w;				// 重み  a x target
	vector<int> sv_index;			// サポートベクトル(0でないa[])のインデックス
	vector<double> err_cache;		// エラー値のキャッシュ
	vector< vector<double> > m_train_set;	// 教師データ
	vector<double> m_y;			// 正解データ
	int train_set_size;				// 教師データの数量
	int data_size;					// 教師データの要素数
	double b;						// 閾値
	double C;						// 
	double eps;						// ラグランジュ乗数評価時の余裕値
	double tolerance;				// KKT条件評価時の余裕値
	double Ei, Ej;					// エラー値
};

int SVM::learning( vector< vector<double> >& train_set, vector<double>& target ){
	// 初期化
	a.resize(train_set.size());
	w.resize(train_set.size());
	err_cache.resize(train_set.size());
	sv_index.clear();
	m_train_set = train_set;
	m_y = target;
	train_set_size = train_set.size();
	if( train_set_size < 1 )
		return 0;
	data_size = train_set[0].size();
	vector< vector<double> >::iterator ts_it;
	for( ts_it = train_set.begin(); ts_it != train_set.end(); ts_it++ )
		if( ts_it->size() != data_size )
			return 0;
	if( target.size() != train_set_size )
		return 0;

	// パラメータ設定
	C = 1000.0;
	eps = 0.001;
	tolerance = 0.001;
	// 初期値設定
	fill(a.begin(), a.end(), 0.0);
	fill(err_cache.begin(), err_cache.end(), 0.0);
	b = 0.0;

	bool alldata = true;
	int changed;
	int loop = 0;
	while(1){
		changed = 0;
		if( loop > 500 )
			break;
		for( int i = 0; i < train_set_size; i++ ){
			if( alldata || (a[i] > eps && a[i] < (C-eps)) ){
				changed += examinUpdate(i);
			}
		}
		if( alldata ){
			alldata = false;
			if( changed == 0 ){
				break;
			}
		}else{
			if(changed == 0){
				alldata = true;
			}
		}
		loop++;
		printf("loop %d : changed %d\n", loop, changed);
	}

	for( int i = 0; i < train_set_size; i++ ){
		if( a[i] != 0.0 ){
			sv_index.push_back(i);
		}
	}
	// 重みw 計算
	vector<int>::iterator it,it0;
	for( it = sv_index.begin(); it != sv_index.end(); it++ ){
		w[*it] = m_y[*it] * a[*it];
	}
	return 1;
}

int SVM::examinUpdate(int i){
	double yFi;
	if( a[i] > eps && a[i] < (C-eps) ){
		Ei = err_cache[i];	// f(x)-y
	}else{
		Ei = f(i) - m_y[i];
	}
	yFi = Ei * m_y[i];		// yf(x)-1

	// KKT条件のチェック
	if( (a[i] < (C-eps) && yFi < -tolerance) || (a[i] > eps && yFi > tolerance) ){
		return update(i);
	}
	return 0;
}

int SVM::update(int i){ // a[i]
	// a[j] の決定
	double max_Ej = 0.0, Ej;
	int max_j = -1;
	// 1
	int offset = (int)(((double)rand()/(double)RAND_MAX) * (double)(train_set_size-1));
	for( int j = 0; j < train_set_size; j++ ){
		int pos = (j+ offset) % train_set_size;
		if( a[pos] > eps && a[pos] < (C-eps) ){
			Ej = err_cache[pos];
			if( fabs(Ej-Ei) > max_Ej ){
				max_Ej = fabs(Ej-Ei);
				max_j = pos;
			}
		}
	}
	if( max_j >= 0 ){
		if( stepSMO(i,max_j) == 1 ){
			return 1;
		}
	}
	// 2
	offset = (int)(((double)rand()/(double)RAND_MAX) * (double)(train_set_size-1));
	for( int j = 0; j < train_set_size; j++ ){
		int pos = (j+ offset) % train_set_size;
		if( a[pos] > eps && a[pos] < (C-eps) ){
			if( stepSMO( i, pos ) == 1 ){
				return 1;
			}
		}
	}
	// 3
	offset = (int)(((double)rand()/(double)RAND_MAX) * (double)(train_set_size-1));
	for( int j = 0; j < train_set_size; j++ ){
		int pos = (j+ offset) % train_set_size;
		if( !(a[pos] > eps && a[pos] < (C-eps)) ){
			if( stepSMO( i, pos ) == 1 ){
				return 1;
			}
		}
	}
	return 0;
}

int SVM::stepSMO( int i, int j ){
	if( i == j )
		return 0;

	double ai_old = a[i], ai_new;
	double aj_old = a[j], aj_new;
	double U,V;
	if( m_y[i] != m_y[j] ){
		U = max(0.0, ai_old - aj_old);
		V = min(C, C+ai_old - aj_old);
	}else{
		U = max(0.0, ai_old + aj_old - C);
		V = min(C,   ai_old + aj_old);
	}
	if( U == V )
		return 0;

	double kii = kernel(m_train_set[i], m_train_set[i]); 
	double kjj = kernel(m_train_set[j], m_train_set[j]);
	double kij = kernel(m_train_set[i], m_train_set[j]);
	double k =  kii + kjj - 2.0*kij;
	if( a[j] > eps && a[j] < (C-eps) ){
		Ej = err_cache[j];
	}else{
		Ej = f(j) - m_y[j];
	}

	bool bClip = false;
	if( k <= 0.0 ){
		// ai = U のときの目的関数の値
		ai_new = U;
		aj_new = aj_old + m_y[i] * m_y[j] * (ai_old - ai_new);
		a[i] = ai_new; // 仮置き
		a[j] = aj_new;
		double v1 = f(j) + b - m_y[j] * aj_old * kjj - m_y[i] * ai_old * kij;
		double v2 = f(i) + b - m_y[j] * aj_old * kij - m_y[i] * ai_old * kii;
		double Lobj = aj_new + ai_new - kjj * aj_new * aj_new / 2.0 - kii * ai_new * ai_new / 2.0 
					 -m_y[j] * m_y[i] * kij * aj_new * ai_new
					 -m_y[j] * aj_new * v1 - m_y[i] * ai_new * v2;
		// ai = V のときの目的関数の値
		ai_new = V;
		aj_new = aj_old + m_y[i] * m_y[j] * (ai_old - ai_new);
		a[i] = ai_new; // 仮置き
		a[j] = aj_new;
		v1 = f(j) + b - m_y[j] * aj_old * kjj - m_y[i] * ai_old * kij;
		v2 = f(i) + b - m_y[j] * aj_old * kij - m_y[i] * ai_old * kii;
		double Hobj = aj_new + ai_new - kjj * aj_new * aj_new / 2.0 - kii * ai_new * ai_new / 2.0 
					 -m_y[j] * m_y[i] * kij * aj_new * ai_new
					 -m_y[j] * aj_new * v1 - m_y[i] * ai_new * v2;

		if( Lobj > Hobj + eps ){
			bClip = true;
			ai_new = U;
		}else if( Lobj < Hobj - eps ){
			bClip = true;
			ai_new = V;
		}else{
			bClip = true;
			ai_new = ai_old;
		}
		a[i] = ai_old; // 元に戻す
		a[j] = aj_old;
	}else{
		ai_new = ai_old + (m_y[i] * (Ej-Ei) / k);
		if( ai_new > V ){
			bClip = true;
			ai_new = V;
		}else if( ai_new < U ){
			bClip = true;
			ai_new = U;
		}
	}
	if( fabs(ai_new - ai_old) < eps * (ai_new+ai_old+eps) ){
		return 0;
	}

	// a[j]更新
	aj_new = aj_old + m_y[i] * m_y[j] * (ai_old - ai_new);
	// b更新
	double old_b = b;
	if( a[i] > eps && a[i] < (C-eps) ){
		b += Ei + (ai_new - ai_old) * m_y[i] * kii +
				  (aj_new - aj_old) * m_y[j] * kij;
	}else if( a[j] > eps && a[j] < (C-eps) ){
		b += Ej + (ai_new - ai_old) * m_y[i] * kij +
				  (aj_new - aj_old) * m_y[j] * kjj;
	}else{
		b += (Ei + (ai_new - ai_old) * m_y[i] * kii +
			       (aj_new - aj_old) * m_y[j] * kij +
		      Ej + (ai_new - ai_old) * m_y[i] * kij +
			       (aj_new - aj_old) * m_y[j] * kjj ) / 2.0;
	}
	// err更新
	for( int m = 0; m < train_set_size; m++ ){
		if( m == i || m == j ){
			continue;
		}else if( a[m] > eps && a[m] < (C-eps) ){
			err_cache[m] = err_cache[m] + m_y[j] * (aj_new - aj_old) * kernel( m_train_set[j], m_train_set[m] )
										+ m_y[i] * (ai_new - ai_old) * kernel( m_train_set[i], m_train_set[m] )
										+ old_b - b;
		}
	}

	a[i] = ai_new;
	a[j] = aj_new;
	if( bClip  ){
		if( ai_new > eps && ai_new < (C-eps) ){
			err_cache[i] = f(i) - m_y[i];
		}
	}else{
		err_cache[i] = 0.0;
	}
	err_cache[j] =  f(j) - m_y[j];

	return 1;
}

double SVM::discriminate( vector<double> test ){
	if( test.size() != data_size ){
		return 0.0;
	}
	vector<int>::iterator it;
	double eval = 0.0;
	for( it = sv_index.begin(); it != sv_index.end(); it++ ){
		eval += w[*it] * kernel(m_train_set[*it], test);
	}
	eval -= b;

	return eval;
}

double SVM::f(int i){
	double F = 0.0;
	for( int j = 0; j < train_set_size; j++ ){
		if( a[j] == 0.0 )
			continue;
		F += a[j] * m_y[j] * kernel(m_train_set[j], m_train_set[i]);
	}
	F -= b;
	return F;
}

#define GAUSSIAN

double SVM::kernel( const vector<double> p, const vector<double> q )
{
	double r = 0.0;
#ifndef GAUSSIAN
	// 多項式カーネル
	double p = 4.0;		// Tuning Parameter
	r = 1;				// Tuning Parameter
	for( int i = 0; i < data_size; i++ ){
		r += p[i] * q[i];
	}
	r = pow( r, p );
#else
	// ガウシアンカーネル
	double delta = 1.0;		// Tuning Parameter
	for( int i = 0; i < data_size; i++ ){
		r += (p[i] - q[i]) * (p[i] - q[i]);
	}
	r = -r / (2*delta*delta);
	r = exp(r);
#endif
	return r;
}

コメント

コメントの投稿


管理者にだけ表示を許可する

トラックバック

トラックバック URL
http://msirocoder.blog35.fc2.com/tb.php/35-670be246
この記事にトラックバックする(FC2ブログユーザー)

上記広告は1ヶ月以上更新のないブログに表示されています。新しい記事を書くことで広告を消せます。