In this post I'm going to share my experiences fine-tuning GPT-2 to autocomplete cardiology sentences. The post will focus on the experiments I ran and the quality of the generated text. If you just want to play around with the model, you can check it out here. In a follow-up post I will share some of the more technical details of deploying the model.
GPT-2
For those that don't live in the AI world, GPT-2 (Generative Pre-trained Transformer 2) is a deep learning model from OpenAI that can be used to generate text. It achieves this by learning to predict the next word given the preceding words from a large corpus of text. OpenAI trained GPT-2 on a massive dataset of 40GB of general text from the internet. They trained models with 117M (small), 345M (medium), 762M (large) and 1542M (extra-large) parameters. They (eventually) made these models publicly available and Huggingface developed a web app where you can play around with them. If you try to get it to generate text from the domain of cardiology you will see it just doesn't work. This is not surprising as it has never been trained on the 'technical' content you would expect to see in text about cardiology. However, the great thing about this model is that we can build on the general language knowledge it has already learned. This is done by fine-tuning the model on a new dataset for a desired purpose, e.g. autocompleting cardiology sentences.
Dataset
In order to fine-tune GPT-2 we first need a dataset. Finding large amounts of text in a specialized domain proved to be quite tricky. In the end I settled on textbooks over journal articles as I assumed the text would be more general. After downloading the ebooks (14 in total), I stripped the text from them and did some basic cleaning such as removing citation numbers. I then split the dataset into train and validation datasets, which were 9.5MB and 1MB respectively. As a reminder, OpenAI pretrained this model on a dataset roughly 4,000 times bigger!
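As a rough illustration of this kind of preprocessing (the exact cleaning steps and the 90/10 split ratio here are my assumptions, not a record of what was actually run), removing bracketed citation numbers and holding out a validation set might look like this:

```python
import re
import random

def clean_text(text):
    # Remove bracketed citation numbers such as "[12]" or "[3, 4]",
    # including any whitespace immediately before them
    text = re.sub(r"\s*\[\d+(?:,\s*\d+)*\]", "", text)
    # Collapse any runs of spaces/tabs left behind
    return re.sub(r"[ \t]+", " ", text).strip()

def train_val_split(paragraphs, val_fraction=0.1, seed=42):
    # Shuffle paragraphs and hold out a fraction for validation
    rng = random.Random(seed)
    shuffled = paragraphs[:]
    rng.shuffle(shuffled)
    n_val = int(len(shuffled) * val_fraction)
    return shuffled[n_val:], shuffled[:n_val]

sample = "Stents reduce restenosis [12] and improve outcomes [3, 4]."
print(clean_text(sample))  # "Stents reduce restenosis and improve outcomes."
```

Splitting at the paragraph level (rather than mid-sentence) keeps each training example coherent, which matters when the model is learning to continue sentences.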
Model
As mentioned, there are 4 different sizes of GPT-2 model. The large and extra-large models require a lot of compute power to train, and the small cardiology dataset possibly does not benefit from the extra parameterization. For this reason I focused on fine-tuning the small and medium models. I used the awesome Huggingface transformers library for this. For the experiments I looked at 2 different cases: the first used the original GPT-2 vocabulary, and the second extended it by adding the 150 most common words in the cardiology corpus that did not already appear in the original vocabulary. As an example, here were the 10 most common: stent, catheter, myocardial, aortic, ventricular, distal, stenosis, lesion, revascularization and stents.
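A minimal sketch of how those extra vocabulary words might be selected: count word frequencies across the corpus and keep the most common ones missing from the existing vocabulary. The toy corpus and vocabulary below are invented for illustration; with the transformers library you would then register the new words via `tokenizer.add_tokens(new_words)` followed by `model.resize_token_embeddings(len(tokenizer))`.

```python
import re
from collections import Counter

def most_common_new_words(corpus, existing_vocab, n=150):
    # Lowercased word frequencies across the whole corpus
    counts = Counter(re.findall(r"[a-z]+", corpus.lower()))
    # Keep only words the existing vocabulary does not cover,
    # ordered by descending frequency
    new = [w for w, _ in counts.most_common() if w not in existing_vocab]
    return new[:n]

# Toy example (the real run used 14 textbooks against GPT-2's ~50k-token vocab)
corpus = "The stent was placed across the stenosis. A second stent followed."
vocab = {"the", "was", "placed", "across", "a", "second", "followed"}
print(most_common_new_words(corpus, vocab, n=3))  # ['stent', 'stenosis']
```

Note that GPT-2 actually uses subword tokens rather than whole words, so "missing" here really means a word the tokenizer would otherwise split into several pieces.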
Pre-trained Model Results
How does the original GPT-2 do on generating sensible cardiology text? In all these experiments I provide some context and you can see the generated text from the model in bold.
Generated text for pretrained medium model
A Fractional Flow Reserve of <0.8% of the total volume of the total volume of
The PROSPECT trial found that the use of the drug was associated with a
FAME II was a landmark trial that brought the case to the Supreme Court. The case
Common risk factors for CAD are: 1. A history of cardiovascular disease
Plaque burden as measured by IVUS. The study was approved by the Inst
Thin capped fibroatheroma (FBC) is a rare, but very
The most accurate diagnostic test for CAD is the T-test, which is a simple,
After stent deployment, the team has been able to deploy the new
A stenosis >50% of the time. The most common sten
Angioplasty was first performed by Dr. Robert J. H. Haldane
You don't have to be a cardiologist to know that these sentences are not sensible. The model makes simple mistakes such as associating the word 'trial' with judicial rather than clinical matters. I'm also not sure I would want my coronary artery disease (CAD) diagnosed by the T-test. Next, we'll see how fine-tuning the model impacts the generated text for the same context.
Fine-tuned Model Results
Here I will present the results from fine-tuning the small and medium models using both the original vocabulary and the extended vocabulary.
Small model
This model was fine-tuned with a batch size of 2 and a learning rate of 5e-5. We can check against overfitting to the training dataset by tracking the validation dataset loss. Around the 11th epoch the validation dataset loss starts to increase again and we can assume it is starting to overfit to the training dataset. This happens for both the original and extended vocab models. But, how does the generated text look?
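The overfitting check described above amounts to simple early stopping on the validation loss: train until the loss stops improving, then keep the checkpoint from the best epoch. A sketch of that logic (the loss values below are made up for illustration):

```python
def best_epoch(val_losses, patience=1):
    # Return the epoch (0-indexed) with the lowest validation loss,
    # stopping once the loss has failed to improve for `patience` epochs
    best, best_i, bad = float("inf"), 0, 0
    for i, loss in enumerate(val_losses):
        if loss < best:
            best, best_i, bad = loss, i, 0
        else:
            bad += 1
            if bad > patience:
                break
    return best_i

# Validation loss reaches a minimum and then rises -> keep that checkpoint
losses = [3.2, 2.9, 2.7, 2.6, 2.65, 2.8]
print(best_epoch(losses))  # 3
```

In the Huggingface ecosystem the same effect can be had by evaluating once per epoch and restoring the best checkpoint, rather than hand-rolling the loop.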
Generated text for small model with original vocabulary
A Fractional Flow Reserve of <0.8, a single-stent restenosis with
The PROSPECT trial found a mean of the clinical events in the study of
FAME II was a landmark trial that were randomized to the primary endpoint of death, MI
Common risk factors for CAD are not be considered for the most common risk of the
Plaque burden as measured by IVUS imaging, and IVUS imaging has been shown to
Thin capped fibroatheroma (CAD) is a significant differences in patients
The most accurate diagnostic test for CAD is the use of the use of the use of the
After stent deployment is a stent deployment, a stent deployment
A stenosis >50% of patients with aneurysms are associated with
Angioplasty was first performed by the first-generation stent (CAD)
Generated text for small model with extended vocabulary
A Fractional Flow Reserve of <0.8% in the primary PCI in the primary PCI group
The PROSPECT trial found the use of the use of the use of the
FAME II was a landmark trial that the safety of the safety of the safety of the
Common risk factors for CAD are the most common cause of the patient is the patient
Plaque burden as measured by IVUS is a thin fibrous cap fibroatheroma.
Thin capped fibroatheroma (C) was a significant difference in the use
The most accurate diagnostic test for CAD is the presence of the presence of a significant stenosis in
After stent deployment of stents are used to the stents is
A stenosis >50% stenosis is a significant stenosis is a significant stenosis of
Angioplasty was first performed by the main vessel vessel (PCI) in the
The first thing to notice is that the text does now appear to be from the domain of cardiology, so perhaps some examples could fool those unfamiliar with the field. However, text is also repeated, which is a known issue with language models when the most likely next word is always selected (greedy decoding). Moreover, we can see that there really is very little understanding, as most of the generated text does not make sense given the context; for example, following the word stent with the acronym CAD. Finally, we can see that the model with the extended vocabulary appears to perform worse and is quite incoherent.
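The repetition arises because greedy decoding can get trapped in a high-probability cycle: once the model enters a loop of likely words, the argmax keeps it there. A toy next-word table (the probabilities are invented) shows the mechanism; in practice, sampling strategies such as top-k or nucleus sampling, available via `model.generate(do_sample=True, ...)` in the transformers library, reduce this:

```python
def greedy_decode(next_word_probs, start, steps):
    # Repeatedly pick the most likely next word given the previous one
    words = [start]
    for _ in range(steps):
        probs = next_word_probs[words[-1]]
        words.append(max(probs, key=probs.get))
    return " ".join(words)

# Invented bigram probabilities containing a high-probability cycle
table = {
    "the":   {"use": 0.6, "stent": 0.4},
    "use":   {"of": 0.9, "was": 0.1},
    "of":    {"the": 0.8, "a": 0.2},
    "stent": {"was": 1.0},
    "was":   {"the": 1.0},
}
print(greedy_decode(table, "the", 6))
# "the use of the use of the" -- greedy decoding loops indefinitely
```

A real transformer conditions on the whole context rather than just the last word, but the same argmax feedback loop produces outputs like "the use of the use of the" above.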
Medium model
For the medium model the batch size was reduced to 1 in order to fit in GPU memory, while the same learning rate was used. The model starts to overfit after only 7 epochs. I'm not sure whether this is due to the increased capacity of the medium model or the reduction in the batch size. Let's take a look at the generated text again given the same context.
Generated text for medium model with original vocabulary
A Fractional Flow Reserve of <0.8 Has Unparalleled Effect on Left Main Stenosis. A recent study by
The PROSPECT trial found that the combination of routine angiography and PCI with drug-eluting
FAME II was a landmark trial that demonstrated the superiority of PCI versus CABG in patients with multivessel
Common risk factors for CAD are summarized in Table 5-1. The prevalence of hypertension, diabetes, and
Plaque burden as measured by IVUS is a surrogate marker for the presence of atherosclerotic plaque and is
Thin capped fibroatheroma (TCFA) is a rare, but potentially devastating, complication of coronary
The most accurate diagnostic test for CAD is the coronary angiogram. The angiogram is a two-dimensional
After stent deployment, the stent should be removed from the delivery sheath and the guid
A stenosis >50% in diameter is considered significant and should be treated. If the stenosis is
Angioplasty was first performed by Gruntzig in 1929 and his technique was subsequently modified by Braunwald and
Generated text for medium model with extended vocabulary
A Fractional Flow Reserve of <0.8 Undergoing PCI is associated with a significant reduction in
The PROSPECT trial found that the use of a DES in the setting of
FAME II was a landmark trial that randomized patients with STEMI to PCI or CABG. The
Common risk factors for CAD are summarized in Table 4.1. The risk of
Plaque burden as measured by IVUS is a useful tool to assess the severity of coronary
Thin capped fibroatheroma (TCFA) is a rare, but potentially
The most accurate diagnostic test for CAD is the coronary angiography. It is a simple and inexpensive
After stent deployment, the patient is monitored for the duration of the
A stenosis >50% is defined as a vessel with a diameter <2
Angioplasty was first performed by Andreas Grüntzig in 1929. He was
As you can see, the generated text is substantially better than that for the small model. It still errs on factual details: the model correctly credits Grüntzig with the first angioplasty, but dates it to 1929 rather than 1977. However, most of the generated text is very plausible. The model manages to generate the correct acronym for thin capped fibroatheroma (TCFA) and generates sensible text about plaque burden. We can also see the model is clearly not a physiologist, as it believes that a stenosis >50% is significant.
Summary
These results highlight the difficulty of generating coherent and sensible text in a specialized domain using AI. The small model produced poor results; the medium-sized model, however, produced mostly sensible text, though it was often factually incorrect. Extending the vocabulary appears to have little impact on the generated text. It is important to remember that we should not expect super-intelligent models that are only trained on text. How would a human do at understanding cardiology if they only had access to the words in a textbook and could not look at the corresponding images? An interesting avenue of research is combining text and images so that the model can learn a better understanding of the "world".
Finally, it is important to remember this is far from the limits of what is achievable with current models. With a much larger dataset and a larger model I'm very confident high-quality text can be generated. Further improvements may also be possible with a better decoding strategy, but I did not explore this. However, though language models such as GPT-2 are good at capturing statistical co-occurrences of entities, they are limited in dealing with factual knowledge such as dates or names of people. With this in mind, incorporating a knowledge graph may be necessary to get such facts right. Please give the model a go by visiting cardioassistai.com and share your generated text on Twitter!