Using the CLIP Model on Android

Lulin

0. ONNX Runtime

ONNX Runtime is a cross-platform machine-learning framework for inference and training. It supports models from PyTorch, TensorFlow, Keras, and other frameworks, and it runs on a wide range of platforms.
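
The post does not show how ONNX Runtime is pulled into the Android project. A minimal sketch of the Gradle dependency, using the official onnxruntime-android package (the version number is only an example; check the latest release):

dependencies {
    // ONNX Runtime mobile/Android AAR; the version shown here is only an example
    implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.16.3'
}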

1. Google MLKit Translation API

This is Google's machine-learning kit for mobile. It supports both Android and iOS and covers more than 50 languages. ML Kit provides downloadable translation models, which can also be bundled into the APP for offline use. The dependency and a basic translation call look like this:

dependencies {
    // ...
    implementation 'com.google.mlkit:translate:17.0.2'
}
englishChineseTranslator.translate(text)
    .addOnSuccessListener { translatedText ->
        // Translation successful.
    }
    .addOnFailureListener { exception ->
        // Error.
        // ...
    }
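
The `englishChineseTranslator` above is not constructed in the snippet. A minimal sketch of how such a client is typically created and its offline model downloaded with ML Kit; the variable names, language direction, and download conditions are assumptions, not part of the original post:

import com.google.mlkit.common.model.DownloadConditions
import com.google.mlkit.nl.translate.TranslateLanguage
import com.google.mlkit.nl.translate.Translation
import com.google.mlkit.nl.translate.TranslatorOptions

// Build a translator client; swap source/target to match your use case.
val options = TranslatorOptions.Builder()
    .setSourceLanguage(TranslateLanguage.ENGLISH)
    .setTargetLanguage(TranslateLanguage.CHINESE)
    .build()
val englishChineseTranslator = Translation.getClient(options)

// Download the offline model before the first translation (Wi-Fi only here).
val conditions = DownloadConditions.Builder()
    .requireWifi()
    .build()
englishChineseTranslator.downloadModelIfNeeded(conditions)
    .addOnSuccessListener { /* Model ready; translate() can now be called. */ }
    .addOnFailureListener { /* Model download failed. */ }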

3. PyTorch to ONNX

Converting the Image model

import clip
from PIL import Image
from torch import Tensor
import torch

# Export ImageEncoder of the CLIP to onnx model
if __name__ == '__main__':
    device = "cpu"
    # print(clip.available_models())
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
    i = Image.open("../../image.jpg")
    input_tensor: Tensor = preprocess(i).unsqueeze(0).to(device)
    vit = model.visual
    vit.eval()

    onnx_filename = 'clip-image-encoder.onnx'
    torch.onnx.export(vit, input_tensor, onnx_filename)
    # python -m onnxsim clip-image-encoder.onnx clip-image-encoder-optimized.onnx

Converting the Text model

import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict

import clip


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class TextEncoder(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.temperature = nn.Parameter(torch.tensor(0.07))

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

        print(f"text_projection shape: {self.text_projection.shape}")
        self.dtype = torch.float32

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, text):
        # print(f'text: {text}')
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x


# Export the TextEncoder of CLIP to an ONNX model
if __name__ == '__main__':
    device = "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    model.eval()

    text_encoder = TextEncoder(embed_dim=512, context_length=77, vocab_size=49408,
                               transformer_width=512, transformer_heads=8, transformer_layers=12)

    missing_keys, unexpected_keys = text_encoder.load_state_dict(model.state_dict(), strict=False)

    text_encoder.eval()

    input_tensor = clip.tokenize("a diagram").to(device)
    traced_model = torch.jit.trace(text_encoder, input_tensor)

    onnx_filename = 'clip-text-encoder.onnx'

    torch.onnx.export(text_encoder, input_tensor, onnx_filename)
    # python -m onnxsim clip-text-encoder.onnx clip-text-encoder-optimized.onnx

4. Calling ONNX from Android

ImageEncoder

suspend fun encode(bitmap: Bitmap) = withContext<FloatBuffer>(Dispatchers.Default) {
    val ortEnv = OrtEnvironment.getEnvironment()
    if (ortSession == null) {
        ortSession = ortEnv.createSession(
            AssetUtil.assetFilePath(SeekingApplication.context, modelPath),
            options
        )
    }

    val imageBitmap = preprocess(bitmap)
    val floatBuffer = allocateFloatBuffer(floatBufferElementCount)
    floatBuffer.rewind()
    bitmapToFloatBuffer(
        imageBitmap,
        0, 0,
        224, 224,
        normMeanRGB,
        normStdRGB,
        floatBuffer,
        0,
        MemoryFormat.CONTIGUOUS,
    )
    floatBuffer.rewind()

    val inputName = ortSession?.inputNames?.iterator()?.next()
    val shape: LongArray = longArrayOf(1, 3, 224, 224)
    ortEnv.use { env ->
        val tensor = OnnxTensor.createTensor(env, floatBuffer, shape)
        val output: OrtSession.Result? =
            ortSession?.run(Collections.singletonMap(inputName, tensor))
        val resultBuffer = output?.get(0) as OnnxTensor
        return@withContext (resultBuffer.floatBuffer)
    }
}
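
The `preprocess`, `allocateFloatBuffer`, and `bitmapToFloatBuffer` helpers are not shown in the post; they are assumed to resize/center-crop the bitmap to 224×224 and copy it into a CHW float buffer, in the style of PyTorch Android's TensorImageUtils. If you implement them yourself, the normalization constants should match CLIP's official image preprocessing. A sketch of the assumed values and buffer allocation (the names mirror the snippet above and are not part of the original post):

import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.FloatBuffer

// CLIP's RGB normalization constants (mean / std used by the official OpenAI preprocessing).
val normMeanRGB = floatArrayOf(0.48145466f, 0.4578275f, 0.40821073f)
val normStdRGB = floatArrayOf(0.26862954f, 0.26130258f, 0.27577711f)

// Number of floats for one 224x224 RGB image in CHW layout.
const val floatBufferElementCount = 3 * 224 * 224

// Direct native-order buffer, as expected by OnnxTensor.createTensor.
fun allocateFloatBuffer(numElements: Int): FloatBuffer =
    ByteBuffer.allocateDirect(numElements * 4)
        .order(ByteOrder.nativeOrder())
        .asFloatBuffer()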

TextEncoder

The original post repeats the ImageEncoder snippet under this heading. The version below is a hedged sketch of the text path instead: it assumes the query has already been tokenized into CLIP's 77-entry token-id sequence (e.g. by a port of CLIP's BPE tokenizer) and that `modelPath` points to `clip-text-encoder.onnx`.

suspend fun encode(tokenIds: IntArray) = withContext<FloatBuffer>(Dispatchers.Default) {
    val ortEnv = OrtEnvironment.getEnvironment()
    if (ortSession == null) {
        ortSession = ortEnv.createSession(
            AssetUtil.assetFilePath(SeekingApplication.context, modelPath),
            options
        )
    }

    // CLIP's context length is 77; tokenIds is expected to already be padded to that length.
    // The input dtype must match the exported model: clip.tokenize produces int32 tensors in
    // recent CLIP versions; use a LongBuffer instead if the model was exported with int64 ids.
    val shape: LongArray = longArrayOf(1, 77)
    val inputName = ortSession?.inputNames?.iterator()?.next()
    ortEnv.use { env ->
        val tensor = OnnxTensor.createTensor(env, IntBuffer.wrap(tokenIds), shape)
        val output: OrtSession.Result? =
            ortSession?.run(Collections.singletonMap(inputName, tensor))
        val resultBuffer = output?.get(0) as OnnxTensor
        return@withContext (resultBuffer.floatBuffer)
    }
}
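
The post does not show how the two feature vectors are compared. With CLIP, image/text matching is done by L2-normalizing both embeddings and taking their dot product, i.e. cosine similarity. A minimal sketch, assuming each encoder returns a single 512-dimensional FloatBuffer:

import java.nio.FloatBuffer
import kotlin.math.sqrt

// Cosine similarity between an image embedding and a text embedding.
// Both buffers are assumed to hold one CLIP feature vector of equal length.
fun cosineSimilarity(imageFeature: FloatBuffer, textFeature: FloatBuffer): Float {
    var dot = 0f
    var imageNorm = 0f
    var textNorm = 0f
    for (i in 0 until imageFeature.limit()) {
        val a = imageFeature.get(i)
        val b = textFeature.get(i)
        dot += a * b
        imageNorm += a * a
        textNorm += b * b
    }
    return dot / (sqrt(imageNorm) * sqrt(textNorm))
}

A higher score means a better match, so a gallery search can rank images by this score against the encoded text query.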