Tuning AI Models for Assessment Content Generation

Charles Foster

At Finetune, we are building AI solutions to address some of the most challenging problems in education technology, including automated content generation and AI-powered learning resource classification and recommendations. Because the subject matter our tools must handle spans from K-12 through workforce development, we are investing heavily in methods that allow us to scale up the breadth and depth of what our models cover. Key components of this approach are flexible methods to train specialized neural networks in domains where general-purpose models are insufficient. In this blog post, I would like to share a bit of our journey exploring these methods.

Fine-tuning

Typical fine-tuning of neural language models involves simultaneously optimizing all of their trainable parameters, which number in the billions for networks such as GPT-J. At that scale, both the fine-tuning and inference processes are nontrivial, making widespread deployment of customized models difficult. In our own investigations, a few key issues loomed largest:

  • Simply running these transformer models already presses up against the limits of GPU memory (VRAM), and during fine-tuning the gradients and optimizer states consume additional memory in direct proportion to the number of parameters being optimized (see the rough arithmetic after this list).
  • Modifying all of the parameters in the network can disrupt the information flow learned during pre-training, resulting in catastrophic forgetting and a loss of few-shot capabilities.
  • Serving a customized multi-gigabyte model for each use case would create unacceptable latency and cost burdens. 
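To put the first point in perspective, here is a rough back-of-the-envelope estimate for a GPT-J-sized model, assuming fp32 training with the Adam optimizer and ignoring activations and framework overhead (the exact numbers vary with precision and optimizer choice):

```python
# Rough VRAM arithmetic for full fine-tuning of a GPT-J-sized model.
# Assumes fp32 weights, gradients, and Adam moment buffers; activations
# and framework overhead would add to this total.
params = 6e9                 # GPT-J has roughly 6 billion parameters
weights = params * 4         # fp32 weights
gradients = params * 4       # one fp32 gradient per trainable parameter
adam_state = params * 8      # Adam tracks two fp32 moments per parameter
total_gb = (weights + gradients + adam_state) / 1e9
print(f"~{total_gb:.0f} GB of VRAM before activations")  # ~96 GB
```

Freezing the pre-trained weights removes the gradient and optimizer terms for all but a tiny fraction of the parameters, which is exactly what the methods below exploit.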

These combined concerns motivated us to explore other methods from the recent literature for tuning our neural language models. Luckily, within the past year, the natural language processing research community has developed a bevy of methods that cut down the cost of customizing the behavior of pre-trained language models.

Prompt Tuning

The first approach we pursued is called Prompt Tuning or Soft Prompting (Lester et al. 2021). In this method, the parameters of the network learned during pre-training are held frozen. Instead, we prepend a small number of learnable embedding vectors (typically 10 to 20) to the input prompt tokens and tune those embeddings with the usual language modeling objective on a fine-tuning dataset. These embeddings do not represent tokens of language; we can think of them instead as a dense store of context that the network can condition on, via the attention mechanism, as it makes predictions about the tokens in the sequence.
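To make the mechanics concrete, here is a minimal PyTorch sketch of the idea, assuming a HuggingFace-style causal language model that accepts dense vectors through an inputs_embeds argument; the class name, prompt length, and initialization scale are illustrative choices, not prescriptions from the paper:

```python
import torch
import torch.nn as nn

class SoftPromptModel(nn.Module):
    """Wraps a frozen causal LM with a handful of trainable prompt vectors."""

    def __init__(self, lm, num_soft_tokens=20):
        super().__init__()
        self.lm = lm
        for param in self.lm.parameters():
            param.requires_grad = False  # pre-trained weights stay frozen
        d_model = lm.get_input_embeddings().embedding_dim
        # The only trainable parameters: 10-20 dense vectors, kilobytes in size.
        # Random init for simplicity; initializing from vocabulary embeddings
        # is another common choice.
        self.soft_prompt = nn.Parameter(torch.randn(num_soft_tokens, d_model) * 0.01)

    def forward(self, input_ids, labels=None):
        tok_embeds = self.lm.get_input_embeddings()(input_ids)
        batch_size = input_ids.size(0)
        prefix = self.soft_prompt.unsqueeze(0).expand(batch_size, -1, -1)
        # Prepend the soft prompt as dense vectors, bypassing the
        # token-index interface entirely.
        inputs_embeds = torch.cat([prefix, tok_embeds], dim=1)
        if labels is not None:
            # Mask the loss at the soft-prompt positions (-100 is ignored).
            ignore = labels.new_full((batch_size, prefix.size(1)), -100)
            labels = torch.cat([ignore, labels], dim=1)
        return self.lm(inputs_embeds=inputs_embeds, labels=labels)
```

The need to operate on dense vectors rather than token indices, as in the forward pass above, is also why bolting this onto interfaces built around token IDs takes some care.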

 
Prompt tuning adds only a small runtime cost to the model, since the soft prompts are in the kilobyte range and can be run through the network in parallel. These features make them attractive for serving many concurrent users, as recent deployments of the technique in AI storytelling have indicated. However, integrating soft prompts into popular frameworks like HuggingFace’s transformers is complex, as the interfaces are largely designed to operate on sequences of token indices rather than dense vectors. In addition, as more context is added between the soft prompt and the generation, we begin to see imbalances between the strength of conditioning on the soft prompt and on the token context. Retaining the ability to flexibly add hundreds of tokens of context at runtime was important for us, as it provides additional fine-grained levers of controllability in the item authoring process. If we want to guide the model to focus on content from a particular page of a textbook, or to author a reading comprehension item, or to provide few-shot examples, long-form contextualization matters.

Low Rank Adapters (LoRA)

We later transitioned to a method called LoRA or Low Rank Adapters (Hu et al. 2021). This technique was developed by researchers at Microsoft working on GPT-3-sized models, and builds on earlier adapter approaches. If we think of a transformer as progressively refining its token latent states with each residual layer, the concept of an adapter is to add a small, input-dependent delta (initialized to a no-op) to those latents at a given layer. This gentle nudge can then modulate the network's behavior downstream by, say, emphasizing the parts of the input that are relevant to the task.


Low-rank adapters are a kind of adapter that targets a low-rank subspace, which cuts down the number of new parameters we need to train per weight matrix (from D² to 2 × D × r, where D is in the thousands and the rank r is much smaller). As with soft prompting, we hold the original parameters of the network frozen to preserve whatever knowledge they contain from pre-training, and only adjust the new adapter parameters. In our internal tests, we have seen good indicators from LoRA. Beyond enabling us to tune large models on small hardware budgets, models with adapter layers interspersed also retain much of their original few-shot ability while adapting to the target domain. Notably, integrating low-rank adapters into other frameworks is straightforward, as we can simply swap out existing linear layers for linear + adapter layers as needed.
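To make this concrete, here is a minimal PyTorch sketch of wrapping an existing linear layer with a low-rank adapter in the style of Hu et al. (2021); the rank and scaling values shown are illustrative defaults, not the settings we use in production:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen linear layer plus a trainable low-rank update."""

    def __init__(self, linear: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.linear = linear
        for param in self.linear.parameters():
            param.requires_grad = False  # the pre-trained weight stays frozen
        d_out, d_in = linear.weight.shape
        # Two small matrices stand in for a full D x D update:
        # 2 * D * r trainable parameters instead of D^2.
        self.lora_A = nn.Parameter(torch.randn(r, d_in) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(d_out, r))  # delta starts as a no-op
        self.scale = alpha / r

    def forward(self, x):
        # The adapter contributes a small, input-dependent delta.
        delta = (x @ self.lora_A.T) @ self.lora_B.T
        return self.linear(x) + self.scale * delta
```

Swapping a model's attention projections for wrapped layers like this, and training only the small A and B matrices, is essentially all the integration required; and because the base weights are untouched, a single frozen copy of the model can serve many use cases, with a small adapter loaded per domain.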

Moving Forward

To give a flavor of the kind of improvements we have seen in our explorations, below we share a small selection of generations from models trained with different tuning approaches. These items were all generated using the same two input prompts pertaining to the delivery of nursing care: the first focused on prioritizing the needs of multiple clients and the second on taking action in a team setting.

Item #1: A nurse is assessing a client who is scheduled for outpatient surgery. Which of the following patients would the nurse evaluate next?
(A) A patient with type 1 diabetes
(B) A patient with acute pancreatitis
(C) A patient who requires intravenous antibiotic therapy
(D) A patient with Crohn's disease

Item #2: A client is admitted to the hospital with the diagnosis of congestive heart failure. The client's mother is refusing treatment. Which of the following actions should the nurse take?
(A) Notify the provider of the situation.
(B) Check the patient's medication list.
(C) Check the client's medical chart.
(D) Notify the client's doctor of the client's condition.

Item #3: A nurse is evaluating a patient on the pediatric unit. The client is scheduled for a diagnostic workup, including a chest radiograph. Which of the following patients should the nurse be planning to evaluate next?
(A) A child who has pneumonia and is experiencing fever
(B) A child who has a rash and is not responding to a medication
(C) A child who has a fever and has a temperature >100.4
(D) A child who has a fever and a rash

Item #4: A client with renal failure is in the hospital and is scheduled to undergo a renal transplant. The client has a strong belief that the transplant surgeon is incompetent. The client refuses to allow a doctor from another hospital to perform the transplant. Which of the following actions should the charge nurse take?
(A) Take a statement from the client that the transplant surgeon is competent and give it to the transplant surgeon.
(B) Ask the client for a signed statement to back up her assertion.
(C) Contact the transplant surgeon and find out why the client refuses to allow the surgeon from another hospital to perform the transplant.
(D) Inform the transplant surgeon that the client is refusing the transplant.

Item #5: A nurse is planning to assign clients to various bedside stations. Which of the following clients should the nurse assign to the bedside station closest to the nurses' station?
(A) A client who has a suspected diagnosis of a pulmonary embolism and requires a computed tomography (CT) scan
(B) A client who needs assistance with a bowel movement
(C) A client who has an emergency indication for a chest X-ray
(D) A client who requires a chest X-ray because she has a cough

Item #6: An LPN is caring for a client with an indwelling urinary catheter and is asking the nurse to help her clean the catheter. Which of the following actions should the nurse take?
(A) Explain to the client that she will need to provide the cleaning solution and will need to obtain the client's consent for the procedure.
(B) Ask the LPN for assistance.
(C) Offer to help the client clean the catheter.
(D) Assure the LPN that the nurse will assist her.

We can observe that, while the best items from the baseline model are largely fluent and logically coherent, they tend to be underconstrained (as in #1) or call for very little knowledge of the nursing domain (as in #3). In comparison to the baseline items, items from the prompt tuning and low-rank adapter models contain greater detail in their stimuli, stems, and options. The subject matter is relevant to the domain, calling for specific knowledge in the management of nursing care rather than relying on background knowledge alone. Moreover, the items from the low-rank adapter model have a more consistent form. For instance, the items consistently refer to the “client” as opposed to the “patient”, in accordance with the language that would likely appear in assessments (compare #5 to #1 and #3). The low-rank adapter model also successfully tracks references to multiple individuals within a scenario (compare #6 to #4).

Improvements to domain coverage, stylistic consistency, and logical coherence can translate into significant improvements in the usefulness of neural language models. This is only the beginning: as the technology matures, even more methods will be discovered to create customized, controllable natural language models at scale. And as those methods are discovered, we will continue to incorporate the best from academia, industry, and independent research into Finetune products.

Sincere thanks to Nick Koprowicz, Jesse Hamer, Saad Khan, and Ogden Morse for providing kind, helpful feedback in the development of this blog post.

References

Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., … & Chen, W. (2021). LoRA: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685.

Lester, B., Al-Rfou, R., & Constant, N. (2021). The power of scale for parameter-efficient prompt tuning. arXiv preprint arXiv:2104.08691.