前回、TypeSpecを使ってOpenAPI定義を作成し、Orvalを使ってOpenAPI定義から、SWRのAPIフックを自動生成した。

今回は、その自動生成したAPIフックのMSW(Mock Service Worker)のモック機能を活用し、 StorybookからMSWのモックAPIに対してリクエストを行う方法をまとめる。

MSWとは

MSW(Mock Service Worker)はWEBアプリケーション向けのモックライブラリで、 これを使うとバックエンドAPIをモックしてフロントエンドの開発を進めることができる。

MSWはブラウザのService Workerを使用し、ブラウザ内でAPIへのリクエストをインターセプトしてモックレスポンスを返すことができる。

他の方法に比べて、以下の利点がある。

  • コード内でモックに差し替える方法に比べると、モック用の追加の実装を用意する必要がないので、コードがシンプルになる
  • JSON Serverなどの方法に比べると、フロントサーバーと別プロセスでモックサーバーを立てたりする必要がなく、環境をシンプルに保てる

なお、サーバーサイド(Node.js環境)でも同じように動作させられ、 その場合、Service Workerの代わりにNode.jsのhttpモジュールをインターセプトする。

Storybook

Storybookを使うと、UIコンポーネントを独立して開発することができる。

Next.js 13では、MSWと相性が悪いようで、 実際にNext.jsアプリケーションを立ち上げて、モックデータを取得することができなかった 。(以下のIssueのコメントよると、Pages routerでもApp routerでも動作しないとのこと)

今回のプロジェクトでは、Storybookをベースにコンポーネントを作成しており、 Storybook上でMSWを利用できれば十分だったので、StorybookからMSWを使う方法をまとめる。

Storybookへの追加

msw用のアドオンが利用できる。

MSW Storybook Addon

以下のコマンドでインストールする。

npm i msw msw-storybook-addon -D

Service workerの作成

Service Worker用のワーカースクリプトを動作させるために、ワーカースクリプトをpublic/ディレクトリに追加する。

npx msw init public/

.storybook/preview.tsの修正

以下のアドオン設定を追加する。 handlersは、この後作成する。

+ import { initialize, mswLoader } from "msw-storybook-addon";
+ import { handlers } from "../mocks/handlers";

+ // Initialize MSW
+ initialize();

const preview: Preview = {
  parameters: {
+   msw: {
+     handlers
+   }
  },
+ loaders: [mswLoader]
};

OrvalでMSWモックを生成

orval.config.tsへ追記

前回の記事の設定に、 mock: true を追加して、 SWRファイルの作成と一緒にモックも作成するように変更する。

module.exports = {
  "user-api": {
    input: {
      target: "./openapi.yml",
    },
    output: {
      mode: "split",
      target: "./api/endpoints",
      schemas: './api/models',
      client: "swr",
+     mock: true,
    },
    hooks: {
      afterAllFilesWrite: "prettier --write",
    },
  },
};

userAPI.msw.tsの生成

モックの生成を有効化した後、 npx orval を実行すると、以下の userAPI.msw.ts が生成される。 mswのモック定義が自動的に生成されていることがわかる。 モックデータは、@faker-js/fakerを使って、ランダムなデータが生成している。

/**
 * Generated by orval v6.31.0 🍺
 * Do not edit manually.
 * User API
 * OpenAPI spec version: 0.0.0
 */
  faker
} from '@faker-js/faker'
  HttpResponse,
  delay,
  http
} from 'msw'
  User
} from '../models'

export const getUsersGetUsersResponseMock = (): User[] => (Array.from({ length: faker.number.int({ min: 1, max: 10 }) }, (_, i) => i + 1).map(() => ({email: faker.word.sample(), id: faker.word.sample(), name: faker.word.sample()})))

export const getUsersGetUserResponseMock = (overrideResponse: Partial< User > = {}): User => ({email: faker.word.sample(), id: faker.word.sample(), name: faker.word.sample(), ...overrideResponse})


export const getUsersGetUsersMockHandler = (overrideResponse?: User[] | ((info: Parameters<Parameters<typeof http.get>[1]>[0]) => Promise<User[]> | User[])) => {
  return http.get('*/users', async (info) => {await delay(1000);
    return new HttpResponse(JSON.stringify(overrideResponse !== undefined 
            ? (typeof overrideResponse === "function" ? await overrideResponse(info) : overrideResponse) 
            : getUsersGetUsersResponseMock()),
      {
        status: 200,
        headers: {
          'Content-Type': 'application/json',
        }
      }
    )
  })
}

export const getUsersGetUserMockHandler = (overrideResponse?: User | ((info: Parameters<Parameters<typeof http.get>[1]>[0]) => Promise<User> | User)) => {
  return http.get('*/users/:id', async (info) => {await delay(1000);
    return new HttpResponse(JSON.stringify(overrideResponse !== undefined 
            ? (typeof overrideResponse === "function" ? await overrideResponse(info) : overrideResponse) 
            : getUsersGetUserResponseMock()),
      {
        status: 200,
        headers: {
          'Content-Type': 'application/json',
        }
      }
    )
  })
}
export const getUserAPIMock = () => [
  getUsersGetUsersMockHandler(),
  getUsersGetUserMockHandler()
]

mocks/handlers.tsの追加

mocksディレクトリを作成する。

mkdir mocks
cd mocks

先ほど作成したgetUserAPIMockを使って、mocks/handlers.tsにモックAPIのハンドラを追加する。

+ import { getUserAPIMock } from "../api/userAPI.msw.ts";
+
+ export const handlers = getUserAPIMock();

このhanlders.storybook/preview.tsのmswのhandlersの設定に渡され、 Storybook上でMSWのモックAPIを利用できるようになる。

コンポーネントの利用方法は、前回と同様、生成されたSWRのカスタムフックをReactコンポーネントで呼び出せば良い。 内部的には、MSWのモックAPIが呼び出される。

const { data, error, isLoading } = useUsersGetUsers();

モックデータをカスタムしたい場合

stringで定義した場合、 faker.word.sample() モックをカスタムしたい場合は、以下のようにpropertiesオプションでカスタムすることができる。

たとえば、名前、Email、画像URLなどをランダムなデータにしたい場合は、以下のように設定する。

module.exports = {
  "user-api": {
    input: {
      target: "./openapi.yml",
    },
    output: {
      mode: "split",
      target: "./api/endpoints",
      schemas: './api/models',
      client: "swr",
      mock: true,
+     override: {
+       mock: {
+         properties: {
+           name: () => faker.name.fullName(), // nameフィールドはランダムな名前を返すように変更
+           email: () => faker.internet.email(), // emailフィールドはランダムなメールアドレスを返すように変更
+           '/image/': () => faker.image.url(), // imageとつくフィールドは画像URLを返すように変更
+         },
+       }
+     },
      hooks: {
        afterAllFilesWrite: "prettier --write",
      },
    },
  },
};

まとめ

最終的に開発の流れは以下の通りとなる。

  1. TypeSpecでAPI定義をtspファイルで作成
  2. tspファイルからOpenAPIのyamlファイルを自動生成
  3. Orvalを使って、OpenAPIのyamlファイルからSWRフックとMSWのモック実装を自動生成
  4. APIのMSWモックを使って、Storybook上でコンポーネントの開発を進める

前回、TypeSpecでAPI定義した後、OpenAPIのYAMLファイルを自動生成できるようにした。

OpenAPI定義を生成できたので、API用のクライアントコード、 さらにはReactアプリからAPIのデータ取得用のカスタムフックも自動生成したい。

Orvalを使うとSWRのAPIアクセスのためのhookも自動生成できるようなので試してみる。

SWRとは

SWRはデータ取得のためのReact hooksライブラリで、 APIアクセスのためのキャッシュやエラー処理などの機能を提供する。

SWRを使うことで、フロントからAPIリクエストする際に必要になる機能を簡単に実装できる。

以下は、ドキュメントに記載されている例。

function Profile() {
  const { data, error, isLoading } = useSWR('/api/user', fetcher)

  if (error) return <div>failed to load</div>
  if (isLoading) return <div>loading...</div>
  return <div>hello {data.name}!</div>
}

useSWRフックを使って、/api/userエンドポイントからデータを取得している。

Orvalとは

OrvalはOpenAPI(Swagger)仕様から、TypeScriptのクライアントコードを生成するツールで、以下の特徴がある。

  1. TypeScriptモデルの生成
  2. HTTP呼び出しの生成
  3. MSW(Mock Service Worker)を使用したモックの生成

SWRのHooksも生成できる他、React Query、Angularなどのクライアントに対応している。

インストール

まずはOrvalをインストールする。

npm i orval -D

準備

生成するにあたって、TypeSpecのtspファイルとorvalの設定ファイルを用意する。

main.tsp

TypeSpecでシンプルなAPI定義を作成する。

using TypeSpec.Http;

model User {
	id: string;
	name: string;
	email: string;
}

@route("/users")
interface Users {
	@get
	getUsers(): User[];

	@get
	@route("/{id}")
	getUser(@path id: string): User;
}

前回同様の設定で、openapi.ymlを出力する。

$ tsp compile main.tsp --emit @typespec/openapi3

orval.config.ts

今回はSWRを作成したいので、次のドキュメントが参考になる。

orval.config.tsを作成して出力設定を記載する。

module.exports = {
  userstore: {
    input: {
      target: "./openapi.yml",
    },
    output: {
      mode: "split",
      target: "./api/endpoints",
      schemas: "./api/models",
      client: "swr",
    },
  },
};

今回はapiディレクトリ以下に、モデル定義はmodels、SWRファイルはendpointsディレクトリに配置されるようにした。

生成したファイルにprettierを適用

以下の設定で、生成したファイルにprettierを適用できる。

module.exports = {
  userstore: {
    input: {
      target: "./openapi.yml",
    },
    output: {
      mode: "split",
      target: "./api/endpoints",
      schemas: "./api/models",
      client: "swr",
    },
+   hooks: {
+     afterAllFilesWrite: "prettier --write",
+   },
  },
};

実行

以下のコマンドでOrvalを実行する。

$ npx orval
🍻 Start orval v6.31.0 - A swagger client generator for typescript
(node:4636) [DEP0040] DeprecationWarning: The `punycode` module is deprecated. Please use a userland alternative instead.
(Use `node --trace-deprecation ...` to show where the warning was created)
🎉 user-api - Your OpenAPI spec has been converted into ready to use orval!

warningが発生したが、api以下に生成ファイルを出力することができた。

models/user.ts

API で使用されるデータモデルの型定義を含んでいる。 TypeSpecのモデル定義が反映されていることがわかる。

/**
 * Generated by orval v6.31.0 🍺
 * Do not edit manually.
 * User API
 * OpenAPI spec version: 0.0.0
 */

export interface User {
  email: string;
  id: string;
  name: string;
}

endpoints/userAPI.ts

API エンドポイントに対応するSWRフックが自動生成される。

/**
 * Generated by orval v6.31.0 🍺
 * Do not edit manually.
 * User API
 * OpenAPI spec version: 0.0.0
 */

type AwaitedInput<T> = PromiseLike<T> | T;

type Awaited<O> = O extends AwaitedInput<infer T> ? T : never;

export const usersGetUsers = (
  options?: AxiosRequestConfig
): Promise<AxiosResponse<User[]>> => {
  return axios.get(`/users`, options);
};

export const getUsersGetUsersKey = () => [`/users`] as const;

export type UsersGetUsersQueryResult = NonNullable<
  Awaited<ReturnType<typeof usersGetUsers>>
>;
export type UsersGetUsersQueryError = AxiosError<unknown>;

export const useUsersGetUsers = <TError = AxiosError<unknown>>(options?: {
  swr?: SWRConfiguration<Awaited<ReturnType<typeof usersGetUsers>>, TError> & {
    swrKey?: Key;
    enabled?: boolean;
  };
  axios?: AxiosRequestConfig;
}) => {
  const { swr: swrOptions, axios: axiosOptions } = options ?? {};

  const isEnabled = swrOptions?.enabled !== false;
  const swrKey =
    swrOptions?.swrKey ?? (() => (isEnabled ? getUsersGetUsersKey() : null));
  const swrFn = () => usersGetUsers(axiosOptions);

  const query = useSwr<Awaited<ReturnType<typeof swrFn>>, TError>(
    swrKey,
    swrFn,
    swrOptions
  );

  return {
    swrKey,
    ...query,
  };
};

export const usersGetUser = (
  id: string,
  options?: AxiosRequestConfig
): Promise<AxiosResponse<User>> => {
  return axios.get(`/users/${id}`, options);
};

export const getUsersGetUserKey = (id: string) => [`/users/${id}`] as const;

export type UsersGetUserQueryResult = NonNullable<
  Awaited<ReturnType<typeof usersGetUser>>
>;
export type UsersGetUserQueryError = AxiosError<unknown>;

export const useUsersGetUser = <TError = AxiosError<unknown>>(
  id: string,
  options?: {
    swr?: SWRConfiguration<Awaited<ReturnType<typeof usersGetUser>>, TError> & {
      swrKey?: Key;
      enabled?: boolean;
    };
    axios?: AxiosRequestConfig;
  }
) => {
  const { swr: swrOptions, axios: axiosOptions } = options ?? {};

  const isEnabled = swrOptions?.enabled !== false && !!id;
  const swrKey =
    swrOptions?.swrKey ?? (() => (isEnabled ? getUsersGetUserKey(id) : null));
  const swrFn = () => usersGetUser(id, axiosOptions);

  const query = useSwr<Awaited<ReturnType<typeof swrFn>>, TError>(
    swrKey,
    swrFn,
    swrOptions
  );

  return {
    swrKey,
    ...query,
  };
};
  1. axios を使用して API リクエストを行う関数が定義されている。
    • usersGetUsers(): すべてのユーザーを取得
    • usersGetUser(id: string): 特定のユーザーを取得
  2. SWR フックが各エンドポイントに対して生成されている。
    • useUsersGetUsers(): すべてのユーザーを取得するフック
    • useUsersGetUser(id: string): 特定のユーザーを取得するフック

たとえば、以下のようにして、ユーザー情報を取得するできる。

const { data, error, isLoading } = useUsersGetUsers();

swr オプションを使用して、SWR フックの設定を変更できる。 たとえば、キャッシュを無効にする場合は、 enabled: false を設定する。

const { data, error, isLoading } = useUsersGetUsers({
  swr: {
    enabled: false,
  },
});

APIのBaseURLを変更

デフォルトの設定では、生成したAPIクライアントはフロントと同じドメインのエンドポイントにリクエストするようになっている。 実際にはフロントから別のドメインのAPIにリクエストしたい場合がある。

その場合、Orvalで自動生成したSWRは内部的にaxiosを使ってHTTPリクエストをしているので、 以下のようにaxiosのデフォルトのbaseURLを変更することで、リクエスト先のドメインを変更できる。

axios.defaults.baseURL = '<BACKEND URL>';

他の方法など、詳しくは以下のページを参照。

Base URL設定方法

まとめ

以下のプロセスでフロントのAPIクライアントを自動生成できることがわかった。

  1. TypeSpecのtspファイルからOpenAPI定義ファイルを生成
  2. Orvalを使用して、生成したOpenAPI定義からフロントエンド用のSWRフックを自動作成

また、MSWのモックも自動生成できるので、バックエンドが未完成の場合でも、 ひとまずTypeSpecのAPI定義を作れば、自動生成したSWRのフックでフロントエンドの開発を進められる。

このサイトは、Next.jsでブログ記事機能を作っていて、 ブログの記事はmdxファイルで作成し、マークダウンをベースに、一部Reactコンポーネントを読み込むことで記事を作成している。

今回、開発ノートというカテゴリを作って、開発関連のメモをこちらにまとめていくことにした。 その際、シンタックスハイライトが適用されるようにしたので、その方法をメモしておく。

インストール

シンタックスハイライトを実現するために、react-syntax-highlighterというライブラリを使用する。

Next.jsのプロジェクトに以下のコマンドで追加する。

npm install react-syntax-highlighter --save
npm install @types/react-syntax-highlighter --dev

設定

Next.jsでMDXを使用する際、mdx-components.tsxファイルで、 MDXコンテンツ内のHTML要素やコンポーネントをカスタマイズすることができる。

mdx-components.tsxファイルに以下の内容を追加した。

export function useMDXComponents(components: MDXComponents): MDXComponents {
  return {
    ...components,
    pre: ({ children }) => <>{children}</>, // ②
    code: ({ className, children, ...rest }) => { // ③
      const match = /language-(\w+)/.exec(className || "");
      return match ? (
        <SyntaxHighlighter language={match[1]} style={oneDark}>
          {String(children)}
        </SyntaxHighlighter>
      ) : (
        <code {...rest} className={className}> // ④
          {children}
        </code>
      );
    },
  }
}

この設定により、マークダウン内のコードブロックが自動的にシンタックスハイライトされる。

解説

  1. prismのoneDarkスタイルを使用している。
  2. mdxのコードブロックは<pre><code>...</code></pre>の構造になるが、SyntaxHighlighterがpre要素込みのHTMLを生成するため、pre要素は削除して二重にならないようにしている。
  3. code要素では、言語名が取得できた場合は、SyntaxHighlighterを適用している。
  4. コードブロックで言語が指定されない場合や、インラインコードの場合は、通常のcode要素として表示する。

使いたい背景

APIを実装する場合、個人的には、Protocol BuffersのようなDSLでAPIを定義して、それを元にコードを生成する方法が一番しっくりくる。 REST APIの場合は、OpenAPI(Swagger)で仕様を作成して、そこからドキュメントなり、コードを生成したりなどできるけど、 正直YAMLを書くのが辛くて、苦痛な作業だと感じることが多い。

OpenAPIの仕様をGUIで作成するツールも存在するので。それを使えば良いという意見もあると思う。 だけど、仕様を書くのにも読むのにも外部のGUIツールが必要になってしまうのは、何か間違ってるように感じる。 仕様はコードと同じようにテキストで書くべきものだと思う。

あるいは、バックエンドAPIから自動生成するようなツールを使う方法もある。 ただし、この方法はバックエンドの言語やフレームワークに依存してしまうし、 自分がフロントエンド担当でバックエンドを作成しない場合などは、この方法を採用しづらい。

フロントでTypeScriptでAPIの型定義を書いていると、YAML書くよりずっと良いAPI仕様書では?と思うこともしばしばあって、 TypeScript風にAPIの仕様を書ければ理想的だと思っていた。

そんな中、この課題にぴったりなTypeSpecの存在を知ったので、試してみることにした。

TypeSpecとは

TypeSpecは、Microsoft製のAPIの設計と文書化のためのDSLで、TypeScript風の構文を使用してAPIの仕様を記述できる。

公式サイト: https://typespec.io/

インストールとセットアップ

TypeSpecをグローバルにインストール

npm install -g @typespec/compiler

バージョンの確認

tsp --version

新規プロジェクトの作成。自動的にmain.tsp, package.json, tspconfig.yamlなどのファイルが作成される。

tsp init

依存関係のインストール

tsp install

基本的な使用方法

main.tspファイルを作成する。 以下は公式サイトの例として記載されている内容。

using TypeSpec.Http;

model Store {
  name: string;
  address: Address;
}

model Address {
  street: string;
  city: string;
}

@route("/stores")
interface Stores {
  list(@query filter: string): Store[];
  read(@path id: Store): Store;
}

プログラミング経験がある人であれば、特に文法を知らなくてもストアの一覧と詳細取得を行うAPIが定義されていることが理解できると思う。 OpenAPIのYAMLだとこの程度でもかなりの行数を書かないといけないので、TypeSpecの方が断然読みやすい。

OpenAPIの生成

以下のパッケージをインストールする。

npm install @typespec/openapi3

TypeSpecファイルをコンパイルして、OpenAPI仕様を生成する。

tsp compile main.tsp --emit @typespec/openapi3

tspconfig.yamlに設定を記載できる。

上記の--emitオプションは、以下の記載で毎回コマンドラインで指定する必要がなくなる。

emit:
  - "@typespec/openapi3"

--watchでファイル変更を監視して自動再コンパイルできる

tsp compile main.tsp --emit @typespec/openapi3 --watch

デフォルトだと、tsp-output/@typespec/openapi3以下にopenapi.yamlが生成される。 ルートにopenapi.yamlを生成する場合はtspconfig.yamlに以下の設定を追加する。

options:
  "@typespec/openapi3":
    emitter-output-dir: "{output-dir}"
    output-file: "openapi.yml"

output-dir: "{project-root}"

文法について

モデル定義

modelキーワードを使用して、APIで仕様するデータモデルを定義することができる。 TypeScript風にデフォルト値やOptionalなプロパティも表現できる。

model User {
  id: string;
  name: string = "Anonymous";
  description?: string;
}

デコレーターを使って、ドキュメントやバリデーション情報を追加することができる。

model User {
  @doc("ユーザーID")
  @maxLength(50)
  username: string;

  @doc("ユーザーのメールアドレス")
  @pattern("^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$")
  email: string;

  @doc("ユーザーの年齢(18歳以上)")
  @minValue(18)
  age: int32;
}

enumを使って、列挙型を定義することもできる。

enum UserRole {
  Admin,
  Editor,
  Viewer
}

model User {
  role: UserRole;
}

オペレーション (operation)

REST APIでなんらかのエンドポイントを定義したいときは、opキーワードを使用する。

例えば、ユーザー情報を取得するGETエンドポイントを定義する場合は、以下のようになる。

op getUser(userId: string): User;

デコレーターを使って、オペレーションの付加情報を追加できる。

@get
@route("/users/{userId}")
op getUser(userId: string): User;

インターフェース (interface)

interfaceキーワードを使用してインターフェースを定義できる。 インターフェースを使って、複数のオペレーションをグループ化することができる。

ユーザー情報のCRUD操作をまとめたUsersインターフェースを定義する例。

interface Users {
	@get
	@route("/users")
	listUsers(limit?: int32, role?: string): User[];

	@get
	@route("/users/{userId}")
	getUser(userId: string): User;

	@post
	@route("/users")
	createUser(
		user: {
			name: string;
			email: string;
		},
	): User;

	@patch
	@route("/users/{userId}")
	updateUser(
		userId: string,
		user: {
			name?: string;
			email?: string;
		},
	): User;

	@delete
	@route("/users/{userId}")
	deleteUser(userId: string): void;
}

まとめ

実際に、API仕様書を書く案件で使ってみたところ、 ほとんどTypeScriptコードを書くのと同じように、API仕様書を作ることができた。 VS Codeのエクステンションを使って、静的型チェックのサポートを受けることもできるので、 文法ミスなどもすぐに気づくことができるのが良かった。

プログラミングには 驚き最小の法則 という考え方があるけど、 TypeSpecは、APIの仕様書を書くという観点で、TypeScriptプログラマーにとっての驚き最小を実現しているように感じた。

今後、TypeSpecがどれくらい普及していくかはわからないけど、 個人的には、今後のプロジェクトで積極的に使っていきたいと思う。

同時並行で複数案件を進めているので、ここ1年くらいは結構幅広い技術を使ったと思う。思い出せるものを雑多に列挙すると以下の通り。

  • モバイルアプリ
    • Flutter
    • React Native, Expo
  • フロント
    • Next.js (TypeScript)
    • Astro
    • Tailwind CSS
  • バックエンド
    • Go (echo)
    • Python (FastAPI)
    • Ruby (Rails)
  • 組み込み
    • Platformio (C++)
    • ESP32
  • クラウドサービス
    • AWS (ECS, IoT Core)
    • GCP (Cloud Run)
    • Firebase (Auth/Firestore)
    • Planetscale (MySQL)
    • Cloudflare
  • その他
    • GraphDB (Neo4j)
    • gRPC (Go, Flutter)
    • MQTT (ESP32)
    • Terraform
    • Chrome Extension
    • Figma (webデザイン)

ここ数年はモバイルアプリ関連が多めだったけど、最近はバックエンドの案件も多かった。

振り返ってみると、この中だと特に難しいのがC++のファームウェア開発だった(今もやってるけど)。デバッグが大変で、エラーの原因調査にも時間がかかるし、メモリも使いすぎないように注意しないといけないし、ライブラリに頼れず自分で実装しなければいけない部分も多かったりで、だいたい思ったよりも時間がかかる。もうちょっと修行が必要。

モバイルアプリはこれまでReact Nativeを使ってきたけど、去年からFlutterもがっつり使い始めた。クロスプラットフォームアプリを作る場合どちらが好みかは今は何とも言えないけど、Flutterは安定感があって結構いい感じ。

バックエンドに関しては、選択肢が多くて困ってしまうのだけど、特に理由がないときは、迷うところが少なくて、いろんな意味で速いGoが好みかもしれない。数日前にも個人用途でちょっとしたAPIを作ったのだけど、結局Goにした。

あと、開発環境としては、IntelliJかCLion (C++) を使って開発している。もう周りの人はみんなVSCodeを使っているのだけど、やっぱ個人的にはJetBrainsの方が使いやすいように思う。

GitHub Actionを新しく作るとき、GitHub上でテストしては失敗を繰り返してしまう。actを使って、ある程度ローカルで動作検証した上で、GitHubで動かせるようにする。

nektos/act GitHub Repository

インストール

実行にはDocker環境が必要。

macOSの場合、以下でインストールできる。

$ brew install act

mac以外の環境で、goが入ってる場合は、以下でもインストールできる。

$ go install github.com/nektos/act@latest

M1 Macで起動する場合は、 --container-architecture linux/amd64 オプションをつける。

また、以下の問題があり、yarnコマンドが実行できないことがある。

-P または --platform でプラットフォームを指定でき、上のissueに記載されている通り、 -P ubuntu-latest=ghcr.io/catthehacker/ubuntu:js-latest をつけると、yarnコマンドを実行できるようになった。

初回実行後に作成される /.actrc にこれらのオプションを記載しておくと、毎回の実行時にこれらのオプションをつけて実行できる。

--container-architecture linux/amd64
-P ubuntu-latest=ghcr.io/catthehacker/ubuntu:js-latest

基本的な実行方法

引数に何もつけないとpushイベントをトリガーとしたワークフローが実行される。

$ act

pull_request イベントをトリガーとしたい場合は以下のように引数を与える。

$ act pull_request

イベントでトリガーされるワークフロー一覧

イベント名の後に -l または --list オプションをつける。

$ act pull_request -l
Stage  Job ID      Job name    Workflow name  Workflow file    Events      
0      helloworld  helloworld  Hello World    helloworld.yaml  pull_request

ワークフローを直接指定して実行

-W または --workflows で指定したワークフローを実行することができる。

$ act -W .github/workflows/helloworld.yaml pull_request

Pull Requestのマージをローカル実行

マージ時にhello worldと表示するテスト用のworkflowを作成する。 (.github/workflows/helloworld.yaml)

name: Hello World

on:
  pull_request:
    types: [closed]

jobs:
  hello-world:
    if: github.event.pull_request.merged == true
    runs-on: ubuntu-latest
    steps:
      - run: echo "Hello World!"

これをローカルで実行してみると、イベントのプロパティが存在しないため、エラーが発生する。

$ act -W .github/workflows/helloworld.yaml pull_request
Error:  Error in if-expression: "if: github.event.pull_request.merged == true" (Unable to dereference 'merged' on non-struct 'invalid')

これを回避するには、以下のようなイベント情報を書き込んだ、jsonファイルを用意し、 -e または --eventpath で指定する。

参考: https://github.com/nektos/act/wiki/Beginner's-guide#using-event-file-to-provide-complete-event-payload

$ cat pull_request.json
{
  "pull_request": {
    "merged": true
  }
}
$ act -W .github/workflows/helloworld.yaml pull_request -e pull_request.json
[Hello World/hello-world] 🚀  Start image=ghcr.io/catthehacker/ubuntu:js-latest
[Hello World/hello-world]   🐳  docker pull image=ghcr.io/catthehacker/ubuntu:js-latest platform=linux/amd64 username= forcePull=false
[Hello World/hello-world]   🐳  docker create image=ghcr.io/catthehacker/ubuntu:js-latest platform=linux/amd64 entrypoint=["/usr/bin/tail" "-f" "/dev/null"] cmd=[]
[Hello World/hello-world]   🐳  docker run image=ghcr.io/catthehacker/ubuntu:js-latest platform=linux/amd64 entrypoint=["/usr/bin/tail" "-f" "/dev/null"] cmd=[]
[Hello World/hello-world]   🐳  docker exec cmd=[mkdir -m 0777 -p /var/run/act] user=root workdir=
[Hello World/hello-world] ⭐  Run echo "Hello World!"
[Hello World/hello-world]   🐳  docker exec cmd=[bash --noprofile --norc -e -o pipefail /var/run/act/workflow/0] user= workdir=
| Hello World!
[Hello World/hello-world]     Success - echo "Hello World!"

pull_request.json のmergedをfalseに書き換えると、実行されないことが確認できる。

$ cat pull_request.json
{
  "pull_request": {
    "merged": false
  }
}
$ act -W .github/workflows/helloworld.yaml pull_request -e pull_request.json

シークレットがある場合

オプション -s または --secret を使って、シークレットキーを指定できる。複数ある場合は、複数つける。

$ act -s SECRET1=value1 -s SECRET2=value2

実行例(.github/workflows/helloworld.json)

name: Hello World

on:
  push:

jobs:
  hello-world:
    runs-on: ubuntu-latest
    steps:
      - run: echo "secret1=${{ secrets.SECRET1 }}, secret2=${{ secrets.SECRET2 }}"

実行結果

 % act -W .github/workflows/helloworld.yaml -s SECRET1=s1 -s SECRET2=s2
[Hello World/hello-world] 🚀  Start image=ghcr.io/catthehacker/ubuntu:js-latest
[Hello World/hello-world]   🐳  docker pull image=ghcr.io/catthehacker/ubuntu:js-latest platform=linux/amd64 username= forcePull=false
[Hello World/hello-world]   🐳  docker create image=ghcr.io/catthehacker/ubuntu:js-latest platform=linux/amd64 entrypoint=["/usr/bin/tail" "-f" "/dev/null"] cmd=[]
[Hello World/hello-world]   🐳  docker run image=ghcr.io/catthehacker/ubuntu:js-latest platform=linux/amd64 entrypoint=["/usr/bin/tail" "-f" "/dev/null"] cmd=[]
[Hello World/hello-world]   🐳  docker exec cmd=[mkdir -m 0777 -p /var/run/act] user=root workdir=
[Hello World/hello-world] ⭐  Run echo "secret1=${{ secrets.SECRET1 }}, secret2=${{ secrets.SECRET2 }}"
[Hello World/hello-world]   🐳  docker exec cmd=[bash --noprofile --norc -e -o pipefail /var/run/act/workflow/0] user= workdir=
| secret1=***, secret2=***
[Hello World/hello-world]   ✅  Success - echo "secret1=${{ secrets.SECRET1 }}, secret2=${{ secrets.SECRET2 }}"

--secret-file でシークレットをファイルに書いて、指定してもよい。

$ cat .secret
SECRET1=s1
SECRET2=s2

$ act -W .github/workflows/helloworld.yaml --secret-file .secret
[Hello World/hello-world] 🚀  Start image=ghcr.io/catthehacker/ubuntu:js-latest
[Hello World/hello-world]   🐳  docker pull image=ghcr.io/catthehacker/ubuntu:js-latest platform=linux/amd64 username= forcePull=false
[Hello World/hello-world]   🐳  docker create image=ghcr.io/catthehacker/ubuntu:js-latest platform=linux/amd64 entrypoint=["/usr/bin/tail" "-f" "/dev/null"] cmd=[]
[Hello World/hello-world]   🐳  docker run image=ghcr.io/catthehacker/ubuntu:js-latest platform=linux/amd64 entrypoint=["/usr/bin/tail" "-f" "/dev/null"] cmd=[]
[Hello World/hello-world]   🐳  docker exec cmd=[mkdir -m 0777 -p /var/run/act] user=root workdir=
[Hello World/hello-world] ⭐  Run echo "secret1=${{ secrets.SECRET1 }}, secret2=${{ secrets.SECRET2 }}"
[Hello World/hello-world]   🐳  docker exec cmd=[bash --noprofile --norc -e -o pipefail /var/run/act/workflow/0] user= workdir=
| secret1=***, secret2=***
[Hello World/hello-world]   ✅  Success - echo "secret1=${{ secrets.SECRET1 }}, secret2=${{ secrets.SECRET2 }}"

最近、Qdrantというベクトル検索のOSSを読んでいて、タイトルのやり方を知ったのでまとめておく。

部分ソート

検索のレスポンスを作るときなど、長さNの配列のうち値の高い順にk件だけソートしてほしいという時がある。

このような処理はこれまで何度も書いてきたけど、これまでは単純に配列をソートして、上位k件を切り取るという方法で実装していた。

Rustで例を作ると、高い順にk件取得する処理は、以下のような処理を書くイメージ。

fn top_k(v: &mut [i32], k: usize) -> &[i32] {
    v.sort_by_key(|&x| std::cmp::Reverse(x));
    &v[0..k]
}

fn main() {
    let mut v = vec![1, 8, 2, 4, 3];
    assert_eq!(top_k(&mut v, 3), vec![8, 4, 3]);
}

PythonやRustなどではティムソート (wikipedia) というソートアルゴリズムが利用されていて、最良でO(n)、平均と最悪は共にO(nlog(n))の計算量となる。そのため、上記のtopkの処理はO(nlog(n))の計算量となる。

この方法は、単純であるけれど、配列の全件を並べてしまっているため、少し効率が悪くなってしまう。例えばN=10000、k=100のように全体の長さと欲しい件数に差がある場合は、より効率的な方法が存在しそうなことがイメージしやすいと思う。

このようなケースでは、 部分ソート (wikipedia) と呼ばれる手法を用いると効率よく並べることができる。

wikipediaを参照すると、代表的な3つの手法が紹介されているものの、ヒープ構造を使ったアルゴリズムは、BinaryHeapが標準ライブラリで提供されることが多いので、最も簡単に実装できそう。

今回は、N件ある配列のうち、上位k件を降順に取得したいとして考えてみる。

BinaryHeapの計算量

まず、BinaryHeapの各処理における計算コストについて以下の通りになる。

  • 最大値の参照:O(1)
  • 要素の追加(push): 平均O(1)、最悪O(logn)
  • ルートの削除(pop):平均O(logn)、最悪O(logn)

最大値がルートのノードとなるMaxHeapを考えてみる。

Image
https://en.wikipedia.org/wiki/Binary_heapより引用

最大値を取得したい場合は常にルートを見れば良いので、最大値はO(1)となる。

追加する場合がO(1)になるのが少し直感的に分かりにくいけれど、二分木は全体の約半分は葉のノードとなるため、50%の確率で末端への追加になることに注目する。その上を辿る確率は25%、その上は12.5%、となるため、O(1) * 0.5 + O(2) * 0.25 + O(3) * 0.125 + … となり、これはO(2)となる。

ルートを削除する場合は、ツリーの高さ分だけノードを交換していく必要があるので、O (logn)となる。

  1. Pushのコストについて(stackoverflow)
  2. walframalphaでの数列の和の計算

k件の最小ヒープを使って値を取得する方法

ノードのサイズがk個になるまでヒープにデータをpushしていき、サイズがkに達した後はヒープに追加するたびにルートをpopすることでサイズをkに保つようにする。これにより、ヒープに含まれているノードは上位k個の要素が残ることになる。

計算量について考える。k個のノードを持つヒープを挿入コストはO(logk)で、これをn回繰り返すので全体のコストはO(nlogk)となる。これはサイズがkで固定となるため、nが大きくてメモリに乗り切らないような場合やオンラインで値を更新したい場合はよさそう。

この方法がQdrantで採用されていた。(Qdrantの実装

最大ヒープを構築してk回値をpopする方法

こちらの方法は全件(n件)の最大ヒープを構築して、最大値をk回popする方法。popするたびに最大値が取得できるので、順に配列に追加していくと、最大値でソートされた値を取得することができる。

優先度付きキューそのままの使い方になるので、こちらの方が分かりやすいし簡単かもしれない。

計算量はヒープの構築がO(n)で、k回のpopはO(klogn)、合計するとO(n+klogn)となる。nが大きい場合は、lognで抑えれられてる分だけ前述の方法に比べて計算量は小さくなると思う。

Image
k=100のとき、n=100~10000の範囲での計算量

その他の方法

WikipediaのPartial sortingのページには、クイックセレクトとクイックソートを組み合わせる方法や、マージソートとクイックソートを組み合わせる方法が紹介されており、それらはO(n+klogk)となるのでより効率的になる。

他には、k個目まで要素で終了するバブルソートを使った方法もあり、その場合、O(nk)となる。(wikipediaのバブルソートのページ

参考までに、それぞれ計算量がどれくらいになるのか可視化した。

Image
k=100、n=100~10000で比較。O(n+klogn)とO(n+klogk)が重なって表示されている

Image
k=1000、n=1000~10000で比較。kが大きいと2項目の影響が大きくなるので、O(n+klogk)がより効率的であることがわかる

はじめに

情報検索・検索技術 Advent Calendar 2021の16日目の記事です。 こちらの記事では類似画像検索に関して、全体像がわかるようにざっと説明したいと思います。

類似画像検索とは

類似画像検索とは、画像をクエリとして類似画像を検索する仕組みのことです。検索として一般的にイメージされるテキスト検索は、目的のドキュメントをテキストによって検索しますが、類似画像検索は画像を使って目的の画像を探し出します。

例えば、独特な模様の服を探し出したいとき、なかなか特徴を言語化するのが難しく、テキストで検索するのが難しいというケースがあるかと思います。このように見た目が重要な場合は類似画像検索の方が適していることがあります。

Google画像検索にも画像で検索する機能があるため、使ったことがある方も多いのではないかと思います。

Google画像検索

Google画像検索

全文検索のおさらい

全文検索では、テキストクエリにマッチするドキュメントをもれなく見つけ出す必要がありますが、検索のたびにドキュメント全体を愚直に走査すると、検索対象のドキュメント数が大きいほど、どんどん検索時間がかかるようになってしまいます。

これを防ぐため、事前にインデックス(転置インデックス)を構築することで、検索時間を高速化することができます。本の索引はキーワードごとにページ番号のリストを保持していますが、全文検索も同様にキーワードごとにドキュメントIDのリストを保持しています。

また、日本語のドキュメントのキーワードを作るためには、形態素解析を行なってキーワード単位に分解する必要があります。日本語の文章はどこで区切るか自明ではなく、また検索には不要な助詞などを除去したいためです。

最後に、tf-idfなどでキーワードの出現頻度など考慮して文書の類似度の高い順に検索結果を提示します。これがテキスト検索で行われている処理です。

全文検索ではApache LuceneというJava製の検索ライブラリがよく利用されます。これを利用すると上記の処理を行うことができます。

実際の検索システムではパフォーマンスや可用性のためのクラスタ化などまで行う必要があり、これらを全てを自前で実装するのはかなり大変であるため、普通はApache SolrやElasticsearchなどの全文検索システムを活用して検索システムを構築します。これらは内部的にLuceneを利用しています。

類似画像検索の仕組み

画像のベクトル化

類似画像検索では画像をクエリとして画像を検索します。類似画像検索では画像をベクトルに変換して、ベクトルを用いて画像の類似度を計算できるようにします。

画像をベクトルに変換すると、ベクトル同士のユークリッド距離やコサイン類似度などで類似度を計算できます。

画像をベクトルに変換するとベクトルの距離などで類似度計算が可能になる

画像をベクトルに変換するとベクトルの距離などで類似度計算が可能になる

ここで前提として、似た画像は近い距離となるようにベクトル化していました。どのように画像をベクトル化するかは類似画像検索において重要なポイントといえます。

特徴抽出 (Feature extraction)

画像が8ビットカラーで表現されているとき、各ピクセルは(赤、緑、青)がどれだけ含まれているか、それぞれ0から255の範囲で表現されます。この数値をまとめて画像のベクトルとすることができます。

しかし、この数値列をそのまま使って画像の類似度計算をしても良い精度にはならないです。検索対象の物体の位置ズレの影響を受けたり、検索対象外の領域がノイズとなってしまうためです。

このような問題に対処するため、画像の中で有用な特徴を抜き出してベクトル化を行います。この処理を特徴抽出と呼びます。近年では画像の特徴抽出は畳み込みニューラルネットワーク(Convolutional Neural Network; CNN)を用いて行われます。

CNNは、画像認識の分野で大きな成果を上げており、ディープラーニングを用いると人間に匹敵する精度で画像を認識できます。CNNを用いて画像認識を行うと、同じカテゴリであれば多少の位置ズレがあったり少々色味や形が違っていても高い精度で認識できます。

画像検索においては、画像の一部領域で検索したいことが多いです。物体検出(Object detection)の技術を用いて、画像に存在する物体の領域やカテゴリを検出することができます。

例えば顔の類似画像検索をしたい場合、顔領域以外は不要な領域なので、顔領域のみ切り取って特徴抽出します。

画像の中から顔領域だけ抜き出す図

画像の中から顔領域だけ抜き出す図

類似画像検索に適した特徴ベクトルを作るために、距離学習(Metric Learning)を用いることができます。近年は深層学習技術を使った深層距離学習(Deep Metric Learning)が用いられます。

距離学習を用いると、同じクラスは特徴空間上でより近くに、異なるクラスは遠くになるようなベクトルに変換できるように学習します。距離学習では最終的な出力はベクトルとなります。距離学習はクラス数が多く、サンプル数が少ないケースで有用であると言えます。

類似画像検索では類似度計算のためのベクトルに変換を行いたいため、距離学習が適していると言えます。

近似最近傍探索

類似画像検索もテキスト検索と同様に、検索時に全走査して類似度計算を行うと、検索対象のサイズが大きくなるほど検索時間がかかるようになってしまいます。

実際の検索システムの場合、検索が現実的な時間で終わることは重要です。また、納得感のある結果を得るためには、より多くのデータを検索対象にする必要があります。そもそも似ているデータが存在していなければ、似ている画像を提示することができないからです。

この問題に対処するため、類似画像検索では近似最近傍探索(Approximate Nearest Neighbor; ANN)の技術を利用します。

近似最近傍探索は、ある程度精度を犠牲することで検索処理を高速化します。実際の類似画像検索システムでは、ある程度精度を保ちつつ速度を高速化したいため、近似最近傍探索は必要な技術と言えます。

全文検索では転置インデックスを事前に構築しますが、それと同じように近似最近傍探索でも事前にANNのインデックスを構築します。

代表的なANNライブラリーとして、AnnoyNMSLIBFaissNGTなどがありあります。これらは全文検索でのLuceneに対応するライブラリと言えます。

近似最近傍探索は、精度、速度、メモリーのトレードオフが成立します。用いるデータ数や速度など要件に応じて、手法を選定します。近似最近傍探索の手法については東京大学の松井先生の近似最近傍探索の最前線のスライドが参考になります。

おわりに

この記事では、実際の類似画像検索システムの仕組みについてざっと説明しました。

宣伝になってしまうのですが、類似画像検索についてまとめた本を過去に技術書典で頒布していました。また今後、上記の本の内容を修正&追記した本が技術の泉シリーズから販売される予定です。

こちらの本では、実際に類似画像検索システムを実装しながら理解していく構成になっています。また距離学習や近似最近傍探索についても、もう少し踏み込んで解説する予定です。もし興味がある方はそちらも読んでみていただければと思います!

最近、生成モデルを利用して何か作ってみようかと、生成モデルを学んでいます。 今回は最も基本的な敵対的生成ネットワーク(GAN)の基本的な仕組みについて説明します。

GANとは生成モデルの一種であるため、まずは生成モデルについて説明します。

生成モデルとは

生成モデルとは、写真やイラスト、音楽など実際のデータそのものを出力することができる機械学習モデルです。

機械学習では、画像にうつっているものを識別したり、真偽を判定したりと、未来数値の変化を予測したりといった用途でよく利用されています。

画像の分類などを行うモデルは識別モデルと呼ばれていて、2値分類の場合は、分類ラベルを表す「0」、「1」となります。

生成モデルはこれらのモデルとは異なり、入力はランダムなベクトルをとり、出力は画像などのデータそのものとなります。

Image

機械によって存在しないデータを作り出すことができる、まさに人工知能的な技術といえると思います。

敵対的生成ネットワーク(GAN)とは

生成モデルの中で、近年成功を収めている手法が、敵対的生成ネットワーク(GAN)です。

現在成功をおさめているGANは生成モデルに深層学習技術を利用しており、従来よりも精度の高い出力を可能にしました。

たとえば、こちらはGANの一種のStyleGANという手法で生成した画像ですが、 かなり写真に近い画像を生成することができています。

Image
引用:https://en.wikipedia.org/wiki/StyleGAN

GANの基本

GANはデータを作り出す生成モデルと本物と偽物を見分ける識別モデルの2つを組み合わせて学習を行います。

識別モデルは生成モデルのデータ(ラベル: 0)と本物のデータ(ラベル: 1)を使って学習を行います。学習を行うことで識別モデルはデータが本物か偽物かを見分けられるようになります。

Image

生成モデルはランダムな数値ベクトルを入力として、出力を行いますが、学習時は識別モデルと連結し、 識別モデルの出力が本物と判定されるように生成モデルの学習を行います。 つまり、生成モデルは後段の識別モデルが本物の画像だと間違えてしまう画像を作るように学習します。

Image

生成モデルの学習時に注意が必要なのは、識別モデルのパラメータは変更しないことです。 識別モデルは忖度せずに厳しく判定を行なってほしいですが、学習時に識別モデルを変化させてしまうと、 入力画像が似ていなくても本物と識別してしまう弱い識別モデルになってしまいます。

TensorFlowを使ったGANの実装

今回はTensorFlowを使ってGAN(DCGAN)を実装してみたいと思います。 GANの学習は時間がかかるため、GPUが利用できる環境で実行したほうが良いです。 今回は無料でGPUが利用できるGoogleのColaboratory で学習を行います。

  • (1) データの読み込み
  • (2) 生成器の作成
  • (3) 識別器の作成
  • (4) モデルのコンパイル
  • (5) モデルの学習

の順に説明します。実際に動かしたコードはこちらで公開しています。

(1) データの読み込み

TensorFlowのパッケージに付属するmnistのデータを利用して学習を行ってみます。 データのshapeは「(60000, 28, 28)」で返ってきますが、モデルへの入力を「(28, 28, 1)」にしたいため「reshape」します。

from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

また、このデータは「0~255」まで整数を取りますが、入力値を適切にスケーリングしないと上手く学習できないため、 「-1 ~ 1」の範囲を取るように変換します。

x_train = (x_train - 127.5) / 127.5
x_test = (x_test - 127.5) / 127.5

(2) 生成器の作成

KerasのFunctional APIを利用してモデルを作成します。まずモデルで利用するパッケージをimportします。

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, \
   Dense, Conv2D, LeakyReLU, BatchNormalization, \
   Reshape, UpSampling2D, Activation

後で利用する定数を定義します。 kernel_size はモデルの畳み込み層で利用するカーネルサイズです。 noize_vector_dim は生成モデルの入力ベクトルのサイズです。

kernel_size = 5
noize_vector_dim = 100

今回は最終的に 28 x 28 のサイズの画像を作成したいため、 アップサンプリング層 で倍々に増やしたときに、 ちょうど28となるように最初の畳み込み層への入力サイズを決めます。

アップサンプリング層は、行と列を繰り返してサイズを大きくする層です。 参考: https://keras.io/api/layers/reshaping_layers/up_sampling2d/

7 x 7とした場合、 7 x 714 x 1428 x 28 となるため、 最初のサイズは 7 x 7 x N とします。今回は N = 64 としました。

input_layer = Input(shape=(noize_vector_dim,))

x = Dense(units=7*7*64)(input_layer)
x = BatchNormalization(momentum=0.8)(x)
x = Reshape(target_shape=(7, 7, 64))(x)

アップサンプリング層畳み込み層バッチ正規化層活性化層を順に適用し繰り返します。

x = UpSampling2D()(x)
x = Conv2D(filters=128, strides=1,
          kernel_size=kernel_size, padding='same')(x)
x = BatchNormalization(momentum=0.8)(x)
x = LeakyReLU()(x)

x = Conv2D(filters=64, strides=1,
          kernel_size=kernel_size, padding='same')(x)
x = BatchNormalization(momentum=0.8)(x)
x = LeakyReLU()(x)

x = Conv2D(filters=64, strides=1,
          kernel_size=kernel_size, padding='same')(x)
x = BatchNormalization(momentum=0.8)(x)
x = LeakyReLU()(x)

バッチ正規化層を入れると学習が安定することが知られています。

最後に畳み込み層のフィルタの数を1にして、 (28 x 28 x 1) の出力にします。 最後に tanh を活性化関数とします。これにより出力が -1 ~ 1 の範囲を取るようになります。

DCGANの論文では、全ての層で正規化層を入れるとモデルが不安定になるため、 生成器の出力層ではバッチ正規化を行っておらず、ここでも正規化層を入れていないです。 参考: https://arxiv.abs/1511.06434

x = Conv2D(filters=1, strides=1,
          kernel_size=kernel_size, padding='same')(x)
output_layer = Activation("tanh")(x)

入力層と出力層を Model に渡し、生成器を作成します。

generative_model = Model(input_layer, output_layer)

modelのサマリーは以下のようになります。

generative_model.summary()

Image

(3) 識別器の作成

識別用のモデルは、畳み込み層活性化層ドロップアウト層で構成します。 出力層の手前で平坦化層で平坦化し、出力層はシグモイド関数を使って0~1の範囲で出力を行います。

DCGANの論文 では識別器でも入力層以外はバッチ正規化を入れていたようですが、一般に識別器では利用しない方が良いようです。

識別器の学習ではバッチ正規化は本物のバッチと偽物のバッチでパラメータを決定しますが、 生成器では全て偽物のバッチの正規化となってしまい、学習がうまくいかなくなるようです。 (こちらのコメントを参考 )

識別器では、学習の安定化のためにドロップアウト層を利用します。

kernel_size = 5
input_layer = Input(shape=(28, 28, 1))

x = Conv2D(filters=64, strides=2, kernel_size=kernel_size, padding='same')(input_layer)
x = LeakyReLU()(x)

x = Conv2D(filters=64, strides=2, kernel_size=kernel_size, padding='same')(x)
x = LeakyReLU()(x)
x = Dropout(rate=0.3)(x)

x = Conv2D(filters=128, strides=2, kernel_size=kernel_size, padding='same')(x)
x = LeakyReLU()(x)
x = Dropout(rate=0.4)(x)

x = Conv2D(filters=128, strides=1, kernel_size=kernel_size, padding='same')(x)
x = LeakyReLU()(x)
x = Dropout(rate=0.4)(x)

x = Flatten()(x)
output_layer = Dense(units=1, activation='sigmoid')(x)

discriminator_model = Model(input_layer, output_layer)

モデルのサマリーは以下のようになります。

discriminator_model.summary()

Image

(4) モデルのコンパイル

まず識別器のコンパイルを行います。識別器は2値分類を行うため、損失に binary_crossentropy を指定します。

discriminator_model.compile(
   optimizer=Adam(learning_rate=1e-4),
   loss='binary_crossentropy',
   metrics=['accuracy']
)

生成器の学習のためには二つのモデルを連結したモデルを作成します。ただし、生成器の学習時に識別器の学習をストップする必要があります。

Kerasではモデルの trainable という属性を False にすると、コンパイル後に学習が行われなくなります。コンパイルする必要があるので、先にコンパイルした識別器には影響がないです。

生成器も損失に binary_crossentropy を指定します。

input_layer = Input(shape=(noize_vector_dim,))
discriminator_model.trainable = False

img = generative_model(input_layer)
output_layer = discriminator_model(img)

combined_model = Model(input_layer, output_layer)
combined_model.compile(
   optimizer=Adam(learning_rate=1e-4),
   loss='binary_crossentropy',
   metrics=['accuracy'],
)

(5) モデルの学習

モデルの定義が完了したので、学習を行います。 ただし、GANは学習に時間がかかるため、学習がうまくいっているのか途中経過を確認したいです。

TensorBoardを利用すると学習時の様子が観察できます。 損失を文字で出力して眺めても状況が分かりにくいので、こちらを利用することをお勧めします。

ColaboratoryではTensorBoardを以下のマジックコマンドで起動することができます。

%load_ext tensorboard
%tensorboard --logdir logs

Image
実際にTensorBoardで学習を眺めた様子

学習のログをTensorBoardを反映するため、 tf.keras.callbacks.TensorBoard を初期化します。

batch_size = 64

from tensorflow.keras.callbacks import TensorBoard

tensorboard = TensorBoard(log_dir="logs", histogram_freq=0, batch_size=batch_size, write_graph=True)

各種パラメータと、生成器の入力で必要なランダムベクトルを出力するヘルパー関数を定義します。

batch_size = 64
noize_vector_dim = 100
num_epoch = 50000

def get_random_vector():
   return np.random.normal(0, 1, (batch_size, noize_vector_dim))

まず、生成器の出力画像を利用して、識別器の学習を行います。その後連結したモデルを用いて生成器の学習を行います。 これを指定したエポック数で学習を繰り返します。

for epoch in range(num_epoch):
   # 入力ベクトル
   input_vector = get_random_vector()
   # 生成画像
   fake_image = generative_model.predict(input_vector)
   # 生成画像のラベル (全て0)
   fake_label = np.zeros((batch_size, 1))

   # 本物の画像をサンプリング
   index = np.random.randint(0, x_train.shape[0], size=batch_size),
   valid_image = x_train[index]
   # 本物の画像のラベル (すべて1)
   valid_label = np.ones((batch_size, 1))

   # 学習のログ
   epoch_logs = {}

   # 識別器の学習
   fake_loss, fake_acc = discriminator_model.train_on_batch(fake_image, fake_label)
   valid_loss, valid_acc = discriminator_model.train_on_batch(valid_image, valid_label)
   epoch_logs.update({'disc_loss': 0.5*(fake_loss + valid_loss), 'disc_acc': 0.5*(fake_acc + valid_acc)})

   # 生成器の訓練
   input_vector = get_random_vector()
   label = np.ones((batch_size, 1))
   logs = combined_model.train_on_batch(input_vector, label, return_dict=True)
   epoch_logs.update({'gen_'+key: logs[key] for key in logs})

   tensorboard.on_epoch_end(epoch, epoch_logs)

tensorboard.on_train_end(None)

学習経過はTensorBoardで確認を行います。lossやaccuracyは激しく振動しているため、Smoothingを「0.95~0.98」など強めにすると分かりやすいです。

以下はそれぞれ、識別器のaccuracy、識別器のloss、生成器のaccuracy、生成器のlossの経過です。accuracyやlossなどが途中で常に一定の値をとってしまったりすると学習が失敗しています。

Image

Image

Image

Image

学習の注意点

入力画像のスケールを「-1~1」に変更することを忘れないようにしましょう。 私は何度も学習が失敗し続けていたのですが、モデルにミスがあるのかと思いずっとハマってました。

また、最適化の学習率で結果が変わるので、失敗する場合は、学習率を変更してみると改善するかもしれないです。

精度が悪い場合は、単にエポック数が足りないかもしれないです。最初は1000~2000程度でなかなか上手くいかないなと悩んでいたのですが、こちらのリポジトリでは18kエポックほど学習していたので学習回数を増やしてみたところ改善しました。

実験結果

以下のように画像をグレースケールで描画します。

for i in range(200):
    v = get_random_vector()
    y = generative_model.predict(v)
    plt.imshow(y[0, :, :, 0] * 0.5 + 0.5, cmap='gray')
    plt.figure()

50000エポックほど学習した際の出力画像です。手描きの数字画像を緺麗に生成できていると思います。

Image

同じ3という画像をとってみても、色々な出力が行われていることがわかります。

Image

参考

参考書籍

iOS14からIDFAを収集する場合、アプリユーザーの同意が必要になります。広告の計測などを行う場合、事前に許可を求める必要があります。

IDFA (Identifier for Advertisers)

Appleが端末に割り当てるデバイスIDのことです。このIDを利用して広告配信のエンゲージメントやコンバージョンの測定などに利用します。

参考: https://www.adjust.com/ja/glossary/idfa/

ATT (App Tracking Transparency)

iOS14では、許可していないユーザーのIDFAの取得ができなくなりました。そのため、広告のトラッキングなどを行いたい場合、IDFA取得の同意を求める必要があります。

■ オプトイン

広告配信の際に事前に許可を求めること。

■ ATTオプトイン率

トラッキングを許可したユーザー数。 AppsFlyerのこの記事によると、 アプリ全体で41%程度、アプリ一つあたりだと28%とのこと。また、非ゲームアプリの方がオプトイン率が高くなるとのことです。

参考: https://www.appsflyer.com/jp/ios-14-att-dashboard/

LAT (Limit Ad Tracking)

iOS14以前は、追跡型広告の制限はLimit Ad Tracking (LAT) という機能で行ってました。 これはオプトアウト方式、すなわちデフォルトではIDFAを取得できるものの、後から制限する方式でした。

■ LAT ONユーザー

LAT機能をONにしているユーザーのことです。

• • •

開発者としての対応

もし、IDFAの取得を行う場合は、info.plistにTracking Usage Descriptionの設定を行う必要があります。

この設定値として、許可を求める際の文言を設定します。

オプトイン率は3~4割とのことのため、ユーザー目線に立って文章を考えましょう。 (参考: ATTオプトイン率を上げる5つの方法

また、アプリ側の実装で、ATT許可の要求を必要な場面で行います。

参考: https://zenn.dev/yorifuji/articles/ios-app-tracking-transparency

参考 React Nativeの場合

React NativeでAppsFlyerなど利用して広告の効果測定を行っている場合は、 iOS14以降はATT対応のための設定を行う必要があります。行わない場合、非オーガニックの判定ができないです。

AppsFlyerでは起動時にIDFAを取得し送信しようとしますが、ATT許可を行うまで送信を待つ必要があります。 送信を延期するためのtimeToWaitForATTUserAuthorizationを設定する必要があります。

パーミッションの許可を求める際は、react-native-tracking-transparencyなどを利用して行います。

26件中 11 - 20件を表示