MiniGPT4

class mmpretrain.models.multimodal.MiniGPT4(vision_encoder, q_former_model, lang_encoder, tokenizer, task='caption', freeze_vit=True, freeze_q_former=True, num_query_token=32, prompt_template='###Human: {} ###Assistant: ', raw_prompts=None, max_txt_len=32, end_sym='\n', generation_cfg={}, data_preprocessor=None, init_cfg=None)

The multimodal MiniGPT-4 model.

This implementation is modified from https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py

Parameters:
  • vision_encoder (dict) – The config for the vision encoder.

  • q_former_model (dict) – The config for the Q-Former.

  • lang_encoder (dict) – The config for the language model.

  • tokenizer (dict) – The config for the tokenizer.

  • task (str) – The task type, which controls how text is processed. Defaults to 'caption'.

  • freeze_vit (bool) – Whether to freeze the ViT during training. Defaults to True.

  • freeze_q_former (bool) – Whether to freeze the Q-Former during training. Defaults to True.

  • num_query_token (int) – Number of query tokens of the Q-Former. Defaults to 32.

  • prompt_template (str) – Prompt template of the model. Defaults to '###Human: {} ###Assistant: '.

  • raw_prompts (list) – Prompts used for training. Defaults to None.

  • max_txt_len (int) – Maximum token length used during tokenization. Defaults to 32.

  • end_sym (str) – The end-of-sequence symbol. Defaults to '\n'.

  • generation_cfg (dict) – The config for text generation. Defaults to an empty dict.

  • data_preprocessor (BaseDataPreprocessor) – Used to pre-process data sampled by the dataloader into the format accepted by forward(). Defaults to None.

  • init_cfg (dict) – Initialization config dict. Defaults to None.
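This page does not spell out how raw_prompts and prompt_template combine. The sketch below shows one plausible reading, which assumes the convention of the upstream MiniGPT-4 code: raw prompts carry an `<ImageHere>` placeholder, prompts without it are discarded, and each surviving prompt is substituted into the `{}` slot of the template. The helper `build_prompt_list` and the example prompts are hypothetical and are not part of the mmpretrain API.

```python
from typing import List, Optional

# ASSUMPTION: placeholder token borrowed from the upstream MiniGPT-4 code;
# it marks where the image embeddings are spliced into the prompt.
IMAGE_PLACEHOLDER = '<ImageHere>'


def build_prompt_list(raw_prompts: Optional[List[str]],
                      prompt_template: str = '###Human: {} ###Assistant: ',
                      ) -> List[str]:
    """Hypothetical helper: wrap raw prompts in the chat template.

    Prompts without the image placeholder are dropped, because there would
    be nowhere to insert the image embeddings.
    """
    if raw_prompts is None:
        return []
    kept = [p for p in raw_prompts if IMAGE_PLACEHOLDER in p]
    return [prompt_template.format(p) for p in kept]


# Illustrative prompts only; not taken from any shipped config.
raw = [
    '<Img><ImageHere></Img> Describe this image in detail.',
    'A prompt without an image slot.',
]
for prompt in build_prompt_list(raw):
    print(repr(prompt))
```

Filtering at construction time means a malformed prompt is dropped once, up front, and cannot fail later inside a training step.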

encode_img(images)

Encode a batch of input images into image embeddings.
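To show what num_query_token controls, here is a rough shape trace. It assumes encode_img follows the upstream MiniGPT-4 pipeline: ViT features, then Q-Former queries, then a linear projection into the language model's embedding width. The function `trace_encode_img` is hypothetical and only computes tuples. The sizes in the usage line are illustrative, so check them against your own configs.

```python
from typing import Dict, Tuple


def trace_encode_img(batch_size: int, num_patches: int, vit_dim: int,
                     num_query_token: int, qformer_dim: int,
                     llm_dim: int) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical shape trace of a ViT -> Q-Former -> projection encoder."""
    return {
        # ViT: one feature vector per patch token (plus any class token).
        'vit_features': (batch_size, num_patches, vit_dim),
        # Q-Former: num_query_token learned queries cross-attend to the ViT
        # features, so the sequence length no longer depends on num_patches.
        'query_output': (batch_size, num_query_token, qformer_dim),
        # Linear projection into the language model's embedding width.
        'img_embeds': (batch_size, num_query_token, llm_dim),
        # Attention mask over the image tokens (all positions attended).
        'atts_img': (batch_size, num_query_token),
    }


# Illustrative sizes only (check them against your own configs).
shapes = trace_encode_img(batch_size=2, num_patches=257, vit_dim=1408,
                          num_query_token=32, qformer_dim=768, llm_dim=4096)
for name, shape in shapes.items():
    print(name, shape)
```

Under this assumption, every image costs the language model exactly num_query_token input positions, whatever the image resolution.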

forward(images, data_samples=None, mode='predict', **kwargs)

The unified entry point for the forward pass in both training and testing. The method accepts the following modes:

  • "predict": Run a forward pass and return a list of data samples containing the prediction results.

Parameters:
  • images (torch.Tensor) – the preprocessed image tensor of shape (N, C, H, W).

  • data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.

  • mode (str) – The kind of value to return. Defaults to 'predict'.
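This page documents only the "predict" mode, but the class also exposes a separate loss() method. The sketch below shows mode-based dispatch in the style of mmengine models. It assumes that a "loss" mode routes to loss(), which this page does not confirm. `ForwardDispatchSketch` is a hypothetical stand-in with placeholder return values, not the mmpretrain class.

```python
from typing import Any, Dict, List, Optional


class ForwardDispatchSketch:
    """Hypothetical stand-in; not the mmpretrain class."""

    def loss(self, images: List[Any],
             data_samples: Optional[List[Any]] = None) -> Dict[str, float]:
        return {'loss': 0.0}  # placeholder for the real training loss

    def predict(self, images: List[Any],
                data_samples: Optional[List[Any]] = None,
                **kwargs: Any) -> List[str]:
        return ['prediction'] * len(images)  # placeholder, one per image

    def forward(self, images: List[Any],
                data_samples: Optional[List[Any]] = None,
                mode: str = 'predict', **kwargs: Any) -> Any:
        # Route to the training or inference path based on `mode`.
        if mode == 'loss':
            return self.loss(images, data_samples)
        if mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        raise RuntimeError(f'Invalid mode "{mode}".')


model = ForwardDispatchSketch()
print(model.forward(['img0', 'img1']))       # predict path
print(model.forward(['img0'], mode='loss'))  # loss path
```

Raising on an unknown mode makes a typo in the caller fail loudly instead of silently falling through to inference.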

loss(images, data_samples=None)

The forward function used during training.

Parameters:
  • images (torch.Tensor) – The preprocessed image tensor of shape (N, C, H, W).

  • data_samples (List[DataSample], optional) – All elements required during the forward function. Defaults to None.

Returns:

A dictionary of loss components.

Return type:

Dict[str, torch.Tensor]
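The page does not describe how the training targets are built. In the upstream MiniGPT-4 code, only the answer text (with end_sym appended before tokenization) is supervised. Every position occupied by the BOS token, the prompt, or the image embeddings is labeled with the ignore value, and so is any padding in the answer. The sketch below reproduces that masking on plain lists, assuming mmpretrain follows the same scheme. `build_targets` is hypothetical, and the token ids are illustrative.

```python
from typing import List

# PyTorch's CrossEntropyLoss ignores this label value by default.
IGNORE_INDEX = -100


def build_targets(num_prefix_tokens: int, answer_ids: List[List[int]],
                  pad_id: int) -> List[List[int]]:
    """Hypothetical helper: per-position labels for LM training.

    Each row is laid out as [prefix | answer]. Only answer tokens are
    supervised: prefix positions (e.g. BOS, prompt text, image embeddings)
    and padding inside the answer are set to IGNORE_INDEX.
    """
    targets = []
    for ids in answer_ids:
        prefix = [IGNORE_INDEX] * num_prefix_tokens
        answer = [IGNORE_INDEX if tok == pad_id else tok for tok in ids]
        targets.append(prefix + answer)
    return targets


# Illustrative token ids; 0 plays the role of the padding id here.
print(build_targets(num_prefix_tokens=3,
                    answer_ids=[[11, 12, 13], [21, 0, 0]], pad_id=0))
```

Because the ignored positions contribute nothing to the cross-entropy, the loss returned in the dictionary measures only how well the model produces the answer, not how well it reproduces the prompt.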

post_process(outputs, data_samples)

Post-process the generated outputs according to the task.

Parameters:
  • outputs (torch.Tensor) – The generated outputs.

  • data_samples (List[DataSample], optional) – The annotation data of every sample.

Returns:

The list of data samples.

Return type:

List[DataSample]
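As a rough illustration of the caption step, the sketch below cleans generated text that has already been decoded to a string (for example by a tokenizer's batch-decode method). It assumes the model may keep generating past its answer into a new "###" turn, since "###" is the turn separator of the default prompt_template. It also assumes the model may echo an "Assistant:" tag. `clean_caption` is hypothetical, and this page does not document the exact rules mmpretrain applies.

```python
def clean_caption(decoded: str, stop_marker: str = '###') -> str:
    """Hypothetical caption cleanup for already-decoded generated text."""
    # Keep only the first turn: cut at the next template marker, if any.
    answer = decoded.split(stop_marker, 1)[0]
    # Drop a role tag the model may have echoed before its answer.
    answer = answer.split('Assistant:')[-1]
    return answer.strip()


print(clean_caption(' A dog running on a beach.###Human: anything else?'))
print(clean_caption('Assistant: A cat on a sofa.'))
```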

prompt_wrap(img_embeds, atts_img, prompt)

Wrap the image embeddings with the prompt.

Currently, the same prompt is applied to all images in a batch.

Parameters:
  • img_embeds (torch.Tensor) – The embedding of the input images.

  • atts_img (torch.Tensor) – Attention mask of the image embeddings.

  • prompt (str) – The prompt of the batch data.

Returns:

The wrapped embeddings and their attention mask.

Return type:

Tuple[torch.Tensor, torch.Tensor]
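The sketch below illustrates the wrapping on plain Python lists rather than tensors. It assumes the upstream MiniGPT-4 scheme: the prompt is split at an `<ImageHere>` placeholder, both halves are embedded, and the halves are concatenated around each image's embeddings. The mask is then extended to the new length using the first value of each image mask. `prompt_wrap` here is a hypothetical re-implementation, not the mmpretrain method, and `toy_embed` is a hypothetical stand-in for the language model's token-embedding layer.

```python
from typing import Callable, List, Tuple

Vector = List[float]      # stand-in for one embedding vector
TokenSeq = List[Vector]   # stand-in for a (seq_len, dim) tensor


def prompt_wrap(img_embeds: List[TokenSeq], atts_img: List[List[int]],
                prompt: str, embed_text: Callable[[str], TokenSeq],
                placeholder: str = '<ImageHere>',
                ) -> Tuple[List[TokenSeq], List[List[int]]]:
    """Hypothetical list-based sketch of splicing images into a prompt."""
    if not prompt:
        return img_embeds, atts_img
    # Exactly one placeholder is expected; otherwise unpacking fails.
    p_before, p_after = prompt.split(placeholder)
    before, after = embed_text(p_before), embed_text(p_after)
    wrapped, masks = [], []
    for seq, mask in zip(img_embeds, atts_img):
        # The same prompt embeddings surround every image in the batch.
        new_seq = before + seq + after
        wrapped.append(new_seq)
        # Extend the image mask value over the whole wrapped sequence.
        masks.append([mask[0]] * len(new_seq))
    return wrapped, masks


def toy_embed(text: str) -> TokenSeq:
    """Hypothetical embedder: one 1-d vector per whitespace-separated word."""
    return [[float(len(word))] for word in text.split()]


prompt = '###Human: <Img><ImageHere></Img> Describe it. ###Assistant: '
images = [[[0.0]] * 4, [[1.0]] * 4]   # batch of 2, 4 image tokens each
masks_in = [[1] * 4, [1] * 4]
embeds, masks = prompt_wrap(images, masks_in, prompt, toy_embed)
print(len(embeds[0]), masks[0])
```

Embedding the two prompt halves once and reusing them for every image reflects the limitation noted above: one prompt is shared by the whole batch.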