import os
import argparse

from PIL import Image
import torch
from torch import nn
import coremltools as ct

import cn_clip.clip as clip
from cn_clip.clip.utils import _MODELS, _MODEL_INFO, _download, available_models, create_model, image_transform
class ImageEncoder(nn.Module):
    """Wraps the Chinese-CLIP vision tower so tracing sees a single-tensor forward."""

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, image):
        return self.clip_model.encode_image(image)
class TextEncoder(nn.Module):
    """Wraps the Chinese-CLIP text tower so tracing sees a single-tensor forward."""

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, text):
        return self.clip_model.encode_text(text)
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-arch",
        required=True,
        choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
        help="Specify the architecture (model scale) of Chinese-CLIP model to be converted."
    )
    parser.add_argument(
        "--pytorch-ckpt-path",
        default=None,
        type=str,
        help="Path of the input PyTorch Chinese-CLIP checkpoint."
    )
    parser.add_argument(
        "--download-root",
        default=None,
        type=str,
        help="If --pytorch-ckpt-path is None, the official pretrained ckpt will be downloaded under the --download-root directory and converted."
    )
    parser.add_argument(
        "--save-coreml-path",
        required=True,
        type=str,
        help="Path (prefix) of the output converted CoreML Chinese-CLIP text or vision model."
    )
    parser.add_argument(
        "--convert-text",
        action="store_true",
        help="Whether to convert the text encoder (text feature extractor) into CoreML."
    )
    parser.add_argument(
        "--convert-vision",
        action="store_true",
        help="Whether to convert the vision encoder (vision feature extractor) into CoreML."
    )
    parser.add_argument(
        "--precision",
        default="fp16",
        choices=["fp16", "fp32"],
        help="Compute precision (fp16 or fp32) of the converted CoreML model."
    )
    parser.add_argument(
        "--context-length",
        type=int,
        default=52,
        help="The padded length of input text (including [CLS] & [SEP] tokens)."
    )
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    args = parse_args()

    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f"  {name}: {val}")
    # Resolve the input checkpoint: use the local path if given, otherwise download the official pretrained weights.
    if args.pytorch_ckpt_path and os.path.isfile(args.pytorch_ckpt_path):
        input_ckpt_path = args.pytorch_ckpt_path
    elif args.model_arch in _MODELS:
        input_ckpt_path = _download(
            _MODELS[args.model_arch], args.download_root or os.path.expanduser("./cache/clip"))
    else:
        raise RuntimeError(
            f"Model {args.model_arch} not found; available models = {available_models()}")
    # Load the checkpoint and build the full Chinese-CLIP model in fp32 eval mode.
    with open(input_ckpt_path, 'rb') as opened_file:
        checkpoint = torch.load(opened_file, map_location="cpu")

    model = create_model(
        _MODEL_INFO[args.model_arch]['struct'], checkpoint).float().eval()

    resolution = _MODEL_INFO[args.model_arch]['input_resolution']
    preprocess = image_transform(resolution)
    if args.precision == "fp16":
        precision = ct.precision.FLOAT16
    elif args.precision == "fp32":
        precision = ct.precision.FLOAT32
    # Example inputs for tracing; the text tensor is re-created below with int dtype before tracing.
    image = preprocess(Image.new('RGB', (resolution, resolution))).unsqueeze(0)
    text = clip.tokenize([""], context_length=args.context_length)
    if args.convert_text:
        # Trace the text encoder with a dummy tokenized input and convert it to an ML Program.
        text_model = TextEncoder(model)
        text_model.eval()

        text = clip.tokenize([""], context_length=args.context_length).int()

        traced_text_model = torch.jit.trace(text_model, text)

        text_outputs = [ct.TensorType(name="text_features")]
        text_coreml_model = ct.convert(
            traced_text_model,
            inputs=[ct.TensorType(name="text", shape=text.shape)],
            outputs=text_outputs,
            convert_to="mlprogram",
            compute_precision=precision,
            minimum_deployment_target=ct.target.iOS15
        )

        text_coreml_model_path = f"{args.save_coreml_path}.text.mlpackage"
        print(f"save as {text_coreml_model_path}")
        text_coreml_model.save(text_coreml_model_path)
        print(
            f"Text model converted to CoreML and saved at: {text_coreml_model_path}")
    if args.convert_vision:
        # Trace the vision encoder with a dummy image tensor and convert it to an ML Program.
        image_model = ImageEncoder(model)
        image_model.eval()

        image_width = 336 if args.model_arch == "ViT-L-14-336" else 224
        dummy_image_input = torch.rand(1, 3, image_width, image_width)

        traced_image_model = torch.jit.trace(image_model, dummy_image_input)

        image_outputs = [ct.TensorType(name="image_features")]
        image_coreml_model = ct.convert(
            traced_image_model,
            inputs=[ct.TensorType(
                name="image", shape=dummy_image_input.shape)],
            outputs=image_outputs,
            convert_to="mlprogram",
            compute_precision=precision,
            minimum_deployment_target=ct.target.iOS15
        )

        image_coreml_model_path = f"{args.save_coreml_path}.image.mlpackage"
        image_coreml_model.save(image_coreml_model_path)
        print(
            f"Image model converted to CoreML and saved at: {image_coreml_model_path}")
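    # Example invocation (script name and paths below are illustrative, not taken from this repo):
    #   python pytorch_to_coreml.py --model-arch ViT-B-16 --convert-text --convert-vision \
    #       --save-coreml-path ./coreml/vit-b-16 --precision fp16
    #
    # Minimal sketch of loading a saved package with coremltools; the input/output names match
    # the TensorTypes declared above, while the dtype/shape of the dummy array are assumptions:
    #   import numpy as np
    #   import coremltools as ct
    #   image_model = ct.models.MLModel("./coreml/vit-b-16.image.mlpackage")
    #   image_features = image_model.predict(
    #       {"image": np.random.rand(1, 3, 224, 224).astype(np.float32)})["image_features"]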