Using multi-modal large language models for automated image captioning. Rich captions can be used when training Stable Diffusion DreamBooth models or LoRAs.
Video:
Links/Resources
Recognize Anything
https://github.com/xinyu1205/recognize-anything
Kosmos-2
https://huggingface.co/microsoft/kosmos-2-patch14-224
BLIP-2 OPT-2.7B 8-bit Quantized Model by Mediocreatmybest
https://huggingface.co/Mediocreatmybest/blip2-opt-2.7b_8bit
In [ ]:
import os
import glob
import shutil
from PIL import Image
from tqdm.auto import tqdm
In [ ]:
#!git clone https://github.com/xinyu1205/recognize-anything.git
#%cd recognize-anything
In [ ]:
#!pip install timm transformers fairscale pycocoevalcap
In [ ]:
filepaths = glob.glob("/FILE DIR HERE/**/*.jpg", recursive=True)
print(len(filepaths))
In [ ]:
badfilelist = []
for imagefile in filepaths:
    try:
        im = Image.open(imagefile)
        im = im.convert("RGB")
    except Exception:
        print(imagefile + " is invalid")
        badfilelist.append(imagefile)
print(badfilelist)
In [ ]:
#set to 1 and run the next cell to delete files identified as broken
dorundelete = 0
In [ ]:
if dorundelete == 1:
    for deletefile in badfilelist:
        if os.path.isfile(deletefile):
            os.remove(deletefile)
            print(deletefile + " deleted")
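The RAM++ and Tag2Text cells below expect local checkpoints under pretrained/. A minimal download sketch using huggingface_hub; the repo IDs and filenames are assumptions based on the Recognize Anything project's Hugging Face uploads, so verify them against the recognize-anything README before relying on this.
In [ ]:
#Hypothetical checkpoint download -- repo IDs/filenames are assumptions,
#check the recognize-anything README for the authoritative locations
from huggingface_hub import hf_hub_download
import os
import shutil
os.makedirs("pretrained", exist_ok=True)
for repo_id, filename in [
    ("xinyu1205/recognize-anything-plus-model", "ram_plus_swin_large_14m.pth"),
    ("xinyu1205/recognize_anything_model", "tag2text_swin_14m.pth"),
]:
    cached = hf_hub_download(repo_id=repo_id, filename=filename)
    shutil.copy(cached, os.path.join("pretrained", filename))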
In [ ]:
###RAM PLUS: load the RAM-Plus model
import torch
from PIL import Image
from ram.models import ram_plus
from ram import inference_ram as inference
from ram import get_transform
image_size = 384
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = get_transform(image_size=image_size)
#load model
model = ram_plus(pretrained="pretrained/ram_plus_swin_large_14m.pth",
                 image_size=image_size,
                 vit='swin_l')
model.eval()
model = model.to(device)
In [ ]:
###RAM PLUS tagging loop
import os
from tqdm.auto import tqdm
caption_ext = ".caption"
#Tags to prepend
#specified_tags = "Tag, "
specified_tags = ""
for imagefile in tqdm(filepaths):
    image = transform(Image.open(imagefile).convert("RGB")).unsqueeze(0).to(device)
    res = inference(image, model)
    basefile = os.path.basename(imagefile)
    workdir = os.path.dirname(imagefile)
    noext = basefile.split(".")[0]
    capt_file = noext + caption_ext
    print("File: ", capt_file)
    #RAM++ separates tags with " |"; convert to comma-separated
    modeltags = res[0].replace(" |", ",")
    prefix = "" if len(specified_tags) < 1 else specified_tags + ", "
    print(prefix + modeltags)
    with open(os.path.join(workdir, capt_file), "w") as filehandle:
        filehandle.write(prefix + modeltags)
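A quick sanity check before moving on (a sketch, using the same first-dot filename convention as the loop above) that every image now has a caption sidecar:
In [ ]:
#Count images that still lack a .caption sidecar
missing = [p for p in filepaths
           if not os.path.isfile(os.path.join(os.path.dirname(p),
                                              os.path.basename(p).split(".")[0] + caption_ext))]
print(len(missing), "images still lack captions")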
In [ ]:
###TAG2TEXT
'''
* The Tag2Text Model
* Written by Xinyu Huang
'''
import torch
from PIL import Image
from ram.models import tag2text
from ram import inference_tag2text as inference
from ram import get_transform
image_size = 384
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = get_transform(image_size=image_size)
# delete some tags that may disturb captioning
# 127: "quarter"; 2961: "back", 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
#delete_tag_index = [127,2961, 3351, 3265, 3338, 3355, 3359]
delete_tag_index = []
#load model
model = tag2text(pretrained="pretrained/tag2text_swin_14m.pth",
                 image_size=image_size,
                 vit='swin_b',
                 delete_tag_index=delete_tag_index)
model.threshold = 0.68  # threshold for tagging
model.eval()
model = model.to(device)
In [ ]:
###TAG2TEXT captioning loop
import os
from tqdm.auto import tqdm
caption_ext = ".caption"
#Text to prepend to each caption
prepend = ""
for imagefile in tqdm(filepaths):
    image = transform(Image.open(imagefile).convert("RGB")).unsqueeze(0).to(device)
    basefile = os.path.basename(imagefile)
    workdir = os.path.dirname(imagefile)
    noext = basefile.split(".")[0]
    capt_file = noext + caption_ext
    capt_path = os.path.join(workdir, capt_file)
    print("File: ", capt_file)
    #reuse tags from an existing caption file (e.g. written by the RAM++ pass)
    #as user-specified tags; reset each iteration so tags don't leak between images
    specified_tags = ""
    if os.path.isfile(capt_path) and os.path.getsize(capt_path) > 5:
        with open(capt_path, "r") as filehandle2:
            specified_tags = filehandle2.read().replace(", ", " | ")
    res = inference(image, model, specified_tags)
    modeltags = res[0].replace(" |", ",")                  # model-predicted tags
    specout = res[1].replace(" |", ",") if res[1] else ""  # user-specified tags echoed back
    caption = prepend + res[2] + ", " + modeltags          # res[2] is the generated caption
    if specout:
        caption = caption + ", " + specout
    print(caption)
    with open(capt_path, "w") as filehandle:
        filehandle.write(caption)
In [ ]:
###BLIP2
In [ ]:
import glob
#openpath = "/mnt/f/positions/img/**/*.png"
openpath = "FILE DIR HERE/**/*.png"
filepaths = glob.glob(openpath, recursive=True)
print(len(filepaths))
In [ ]:
###blip2
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Mediocreatmybest/blip2-opt-2.7b_8bit")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Mediocreatmybest/blip2-opt-2.7b_8bit",
    #load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
model = torch.compile(model)
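Before running the full loop, a single-image smoke test (a sketch; assumes filepaths is non-empty) confirms the model loaded correctly. Passing only images to the processor produces an unconditional caption:
In [ ]:
#Smoke test: unconditional caption for the first image
image = Image.open(filepaths[0]).convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device="cuda", dtype=torch.float16)
generated_ids = model.generate(**inputs, max_new_tokens=40)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())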
In [ ]:
#resume from a given index if a previous run was interrupted
#filepaths = filepaths[1735::]
print(len(filepaths))
In [ ]:
###blip2 captioning loop
import os
from tqdm.auto import tqdm
caption_ext = ".caption"
#Tags to prepend
#specified_tags = "Tag, "
specified_tags = ""
#first entry "" yields an unconditional caption; add follow-up questions to chain answers
questions = [""]
for imagefile in tqdm(filepaths):
    basefile = os.path.basename(imagefile)
    workdir = os.path.dirname(imagefile)
    noext = basefile.split(".")[0]
    capt_file = noext + caption_ext
    print("File: ", capt_file)
    image = Image.open(imagefile).convert("RGB")
    responses = []
    prompted_text = ""
    for querycount, query in enumerate(questions):
        if querycount == 0:
            prompt = query
            inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16)
            generated_ids = model.generate(**inputs, max_length=80, min_length=20)
        else:
            #feed the previous answer back in with the next question
            prompt = prompted_text + " " + query
            inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16)
            generated_ids = model.generate(**inputs, max_new_tokens=40, num_beams=5, min_length=1,
                                           top_p=0.97, repetition_penalty=2.5, length_penalty=1.0,
                                           temperature=0.97, do_sample=True)
        prompted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        responses.append(prompted_text)
        print(prompted_text)
    #final answer combines all responses
    prompted_text = ", ".join(responses)
    prefix = "" if len(specified_tags) < 1 else specified_tags
    print(prefix + prompted_text)
    #only write when the model produced something substantive, so an existing
    #caption file is not truncated to empty
    if len(prompted_text) > 10:
        with open(os.path.join(workdir, capt_file), "w") as filehandle:
            filehandle.write(prefix + prompted_text)
In [ ]:
###kosmos2
import glob
openpath = "file path/**/*.png"
filepaths = glob.glob(openpath, recursive=True)
print(len(filepaths))
In [ ]:
###kosmos2
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224").to(device="cuda", dtype=torch.float16)
#model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224", load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
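Kosmos-2's post-processor also returns grounded entities (phrase spans plus normalized bounding boxes), which the captioning loop below discards. A minimal sketch for inspecting them on one image (assumes filepaths is non-empty):
In [ ]:
#Inspect grounded entities for a single image
image = Image.open(filepaths[0]).convert("RGB")
inputs = processor(text="<grounding> An image of", images=image, return_tensors="pt").to(device="cuda", dtype=torch.float16)
generated_ids = model.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    image_embeds=None,
    image_embeds_position_mask=inputs["image_embeds_position_mask"],
    max_new_tokens=64,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
processed_text, entities = processor.post_process_generation(generated_text)
print(processed_text)
print(entities)  # [(phrase, (start, end), [(x1, y1, x2, y2), ...]), ...]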
In [ ]:
#for Kosmos2
import os
from tqdm.auto import tqdm
caption_ext = ".caption"
#Tags to prepend
#specified_tags = "Tag, "
specified_tags = ""
questions = ["<grounding>", "<grounding> Describe this image of a Pokemon in detail:"]
for imagefile in tqdm(filepaths):
    basefile = os.path.basename(imagefile)
    workdir = os.path.dirname(imagefile)
    print(imagefile)
    noext = basefile.split(".")[0]
    capt_file = noext + caption_ext
    print("File: ", capt_file)
    image = Image.open(imagefile).convert("RGB")
    responses = []
    for querycount, query in enumerate(questions):
        ###Prepend the previous answer for chat-style chained questions
        #prompt = processed_text + " " + query if querycount > 0 else query
        prompt = query
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device="cuda", dtype=torch.float16)
        generated_ids = model.generate(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            image_embeds=None,
            image_embeds_position_mask=inputs["image_embeds_position_mask"],
            use_cache=True,
            min_length=60 if querycount == 0 else 80,
            max_new_tokens=128,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        processed_text, entities = processor.post_process_generation(generated_text)
        responses.append(processed_text)
    processed_text = ", ".join(responses)
    #strip prompt echoes and tidy punctuation/spacing
    processed_text = processed_text.replace("Describe this image of a Pokemon in detail: ", "")
    processed_text = processed_text.replace("Describe this image in detail: ", "")
    processed_text = processed_text.replace("The image features", "")
    processed_text = processed_text.replace("  ", " ")
    processed_text = processed_text.replace("' s ", "'s ")
    processed_text = processed_text.replace(".. ", "")
    processed_text = processed_text.replace("., ", ",")
    processed_text = processed_text.replace(". ", ",")
    prefix = "" if len(specified_tags) < 1 else specified_tags
    print(prefix + processed_text)
    #only write when the model produced something substantive, so an existing
    #caption file is not truncated to empty
    if len(processed_text) > 10:
        with open(os.path.join(workdir, capt_file), "w") as filehandle:
            filehandle.write(prefix + processed_text)
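Some LoRA/DreamBooth trainers (kohya-ss scripts, for example) look for .txt sidecar captions rather than .caption; whether yours does is an assumption to check against its docs. A small conversion sketch:
In [ ]:
#Copy .caption sidecars to .txt for trainers that expect that extension
#(the .txt expectation is an assumption -- check your trainer's docs)
import glob
import shutil
for capt in glob.glob("FILE DIR HERE/**/*.caption", recursive=True):
    shutil.copy(capt, capt[:-len(".caption")] + ".txt")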