T5 — Text-To-Text Transfer Transformer
T5 — Text-To-Text Transfer Transformer

Published in simplifyai · 3 min read · Just now

T5 learns to understand a task from the prefix (like “summarize:”, “translate:”, “translate text2sql: ”, etc.) and then generates the correct text output.

EXAMPLE of using T5 with Python

INSTALL: pip install transformers datasets torch

DATA

# IMPORT
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset
import torch

# GET DATA
import pandas as pd

file_path = ".../data.csv"
fileContent = pd.read_csv(file_path)
data = pd.DataFrame(fileContent)

# Rename the columns to match the expected names
data = data.rename(columns={
    "natural_language_query": "input",
    "sql_query": "output"
})

print(data.head())  # View first 5 rows

Data Example

BUILD

from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset
import torch

# 1. Load tokenizer and model
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# 2. Prepare the dataset
dataset = Dataset.from_pandas(data)

# 3. Tokenize
def preprocess(example):
    input_text = "translate text2sql: " + example["input"]
    target_text = example["output"]
    model_input = tokenizer(input_text, max_length=64, truncation=True, padding="max_length")
    label = tokenizer(target_text, max_length=64, truncation=True, padding="max_length")
    model_input["labels"] = label["input_ids"]
    return model_input

tokenized_dataset = dataset.map(preprocess)

TRAIN

# 4. Training arguments
training_args = TrainingArguments(
    output_dir="./t5-text2sql",
    per_device_train_batch_size=4,
    num_train_epochs=10,
    logging_dir="./logs",
    logging_steps=1,
    save_strategy="no"
)

# 5. Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset
)

# 6. Train
trainer.train()

Weights & Biases (aka wandb): Think of it as Google Analytics for your machine learning training.
With wandb, you can:

- See live training curves (loss, accuracy)
- View and compare all 3 runs on a dashboard

TEST

def generate_sql(nl_query):
    input_text = "translate text2sql: " + nl_query
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, max_length=64)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Example:
print(generate_sql("show all ORANGES"))
print(generate_sql("List every user"))

RESULT

SAVE IN GOOGLE DRIVE FOR FUTURE SESSION

from transformers import T5Tokenizer, T5ForConditionalGeneration

# After fine-tuning, save the model
model.save_pretrained(".../my_t5_sql_model")
tokenizer.save_pretrained(".../my_t5_sql_model")

Folder Format

Python Code to Load and Test the Model

from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the model from your saved directory
model_path = r".../my_t5_sql_model"
model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = T5Tokenizer.from_pretrained(model_path)

# Function to generate SQL from text
def generate_sql(text):
    input_text = "translate text2sql: " + text
    inputs = tokenizer.encode(input_text, return_tensors="pt")
    outputs = model.generate(inputs, max_length=50, num_beams=4, early_stopping=True)
    sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sql_query

# Example test
test_input = "get all USERS"
sql = generate_sql(test_input)
print("Generated SQL:", sql)