scikit-learnを使用した仮想通貨の相場学習

scikit-learnを使って雰囲気で機械学習をしたので覚書ということでまとめておきます。
github.com


使用したデータ

cryptowatchのapiを使用してbitflyerのJPY/BTCのデータを引っ張ってきました。
Public Market REST API - Cryptowatch

import requests
r = requests.get('https://api.cryptowatch.ch/markets/bitflyer/btcjpy/ohlc')
with open('bitflyer.json') as f:
  write(r.text)

 

前処理

このデータは '[終値, 始値, 高値, 安値, 終値, 量]'というふうにjsonでデータが入っているので学習させるために前処理をします。
前処理では[終値0, 終値1, 終値2... 終値9]というデータ、結果ラベルとして上がったら1, 下がってたら0というようなラベルを生成していきます。

def get_train_data():
    train_X = []
    train_y = []

    order_book_data = json.load(open(filename, 'r'))

    # 終値のみを取り出す
    prices = []
    for value in order_book_data['result']['60']:
        prices.append(value[1])
    print('data size:', len(prices))

    # データの生成
    input_amount = 10
    for index in range(0, len(prices) - input_amount, input_amount):
        x = prices[index:index+input_amount]
        y = 1 if x[-1] < prices[index+input_amount+1] else 0 # 上がってたら1, 下がってる or 同じだったら0
        train_X.append(x)
        train_y.append(y)

    return np.array(train_X), np.array(train_y)

評価

まずは価格データをそのまま学習させても良い結果はでないのでscale_standard関数を使って標準化します。標準化について簡単に説明するとデータを一定の規則について加工したもののことです。
今回の価格データを例に取ると100円, 200円, 150円という価格の値動きを学習するよりも、前のデータよりどれくらい利益がでたかとか何%上がったかなどの相対的な値を学習させたほうが分類においていい結果がでます。

標準化したときとしてないときの分類精度を実験してまとめてるブログがあったので参考までに。
ailaby.com

標準化したあとはcross_val_score関数を用いてスコアを検証します。
これは交差検証というもので、データを分割してモデルの汎化性能を測ります。
なぜこうするかと言うと、テストデータと学習データを分けてしまうとテストデータに対する性能しか測れないからです。
交差検証については以下で詳しく説明されているのでどうぞ。
交差検証 - Wikipedia

train_features, train_labels = get_train_data()

# train_Xとtest_Xの標準化
train_X = scale_standard(train_features)

clf = MLPClassifier()

scores = cross_val_score(clf, train_X, train_labels, cv=5)
print("Scores: ", scores)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))

結果

以下が結果になります。

[~/bitcoin-prediction-py]$ python main.py
data size: 6000
Scores: [ 0.47933884 0.36666667 0.55833333 0.5210084 0.51260504]
Accuracy: 0.49 (+/- 0.13)

うーん、微妙。
ほとんど50%になってますね。ランダムに売買しても大差ない...。

まとめ

結構しょっぱい結果が出てしまいました。こうした方が精度上がるよというのがあれば是非おしえていただきたいです。
今後の改善方法としては海外の取引所のデータを使ってみるというのを考えています。どこかのチャットで、日本の取引所は海外の取引を基本的に追随して動いているという話を聞いたことがあります。本当かどうかはわからないですが、暇なときに試してみようと思います。