# Source code for dialz.dataset

import json
import os

from dataclasses import dataclass
from typing import List
from dotenv import load_dotenv
from transformers import AutoTokenizer

# Import-time side effect: python-dotenv's load_dotenv() loads variables from a
# .env file into the process environment before the token is read below.
load_dotenv()
# Hugging Face access token used when fetching tokenizers; None if HF_TOKEN is
# not set in the environment.
hf_token = os.environ.get("HF_TOKEN")

@dataclass
class DatasetEntry:
    """
    One contrastive pair stored by ``Dataset``.

    Attributes:
        positive (str): The positive example of the pair.
        negative (str): The negative example of the pair.
    """

    positive: str  # positive example text
    negative: str  # negative example text


class Dataset:
    """
    A class to manage a dataset of positive and negative examples.

    Entries are kept in insertion order in ``self.entries``. The dataset
    supports ``len()``, integer indexing and ``str()``.
    """

    def __init__(self):
        """
        Initializes an empty dataset.
        """
        self.entries: List[DatasetEntry] = []

    def add_entry(self, positive: str, negative: str) -> None:
        """
        Adds a new DatasetEntry to the dataset.

        Args:
            positive (str): The positive example.
            negative (str): The negative example.
        """
        self.entries.append(DatasetEntry(positive=positive, negative=negative))

    def add_from_saved(self, saved_entries: List[dict]) -> None:
        """
        Adds entries from a pre-saved dataset.

        Every entry is validated before any is added, so a malformed entry
        leaves the dataset unchanged.

        Args:
            saved_entries (List[dict]): A list of dictionaries, each containing
                "positive" and "negative" keys.

        Raises:
            ValueError: If an entry lacks a "positive" or "negative" key, or
                cannot be indexed by those keys at all.
        """
        pairs = []
        for index, entry in enumerate(saved_entries):
            try:
                pairs.append((entry["positive"], entry["negative"]))
            except (KeyError, TypeError) as err:
                # TypeError covers non-mapping entries such as plain strings.
                raise ValueError(
                    "Each entry must have 'positive' and 'negative' keys "
                    f"(entry at index {index} does not)."
                ) from err
        for positive, negative in pairs:
            self.add_entry(positive, negative)

    def view_dataset(self) -> "List[DatasetEntry]":
        """
        Returns the current dataset as a list of DatasetEntry objects.

        Returns:
            List[DatasetEntry]: The dataset's own entry list (not a copy).
        """
        return self.entries

    def save_to_file(self, file_path: str) -> None:
        """
        Saves the dataset to a JSON file.

        The file is a JSON list of {"positive": ..., "negative": ...} objects,
        the same shape that ``add_from_saved`` accepts.

        Args:
            file_path (str): The path to the file where the dataset will be
                saved.
        """
        records = [
            {"positive": entry.positive, "negative": entry.negative}
            for entry in self.entries
        ]
        with open(file_path, "w", encoding="utf-8") as file:
            json.dump(records, file, indent=4)

    @staticmethod
    def _load_tokenizer(model_name: str):
        """Loads the tokenizer for ``model_name`` using the module's ``hf_token``."""
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        # Reuse EOS as the padding id (presumably because many causal-LM
        # tokenizers define no pad token — kept from the original loaders).
        tokenizer.pad_token_id = tokenizer.eos_token_id
        return tokenizer

    @staticmethod
    def _apply_chat_template(
        tokenizer,
        system_role: str,
        content1: str,
        content2: str,
        add_generation_prompt: bool = True,
    ) -> str:
        """
        Renders a chat prompt with the tokenizer's chat template.

        Args:
            tokenizer: Object exposing ``apply_chat_template`` (e.g. a
                Hugging Face tokenizer).
            system_role (str): Prefix of the system message. When empty, no
                system message is emitted and ``content1`` is ignored.
            content1 (str): Text appended to ``system_role``, followed by a
                period, to form the system message.
            content2 (str): The user message.
            add_generation_prompt (bool): Forwarded to
                ``apply_chat_template``.

        Returns:
            str: The rendered, untokenized prompt text.
        """
        messages = []
        if system_role:
            messages.append(
                {"role": "system", "content": f"{system_role}{content1}."}
            )
        messages.append({"role": "user", "content": content2})
        # tokenize=False returns the template as text; return_tensors has no
        # effect in that mode, so it is not passed.
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )

    @classmethod
    def create_dataset(
        cls,
        model_name: str,
        contrastive_pair: list,
        system_role: str = "Act as if you are extremely ",
        prompt_type: str = "sentence-starters",
        num_sents: int = 300,
    ) -> "Dataset":
        """
        Creates a dataset of positive/negative prompts for a contrastive pair.

        For each of the first ``num_sents`` prompt variations read from
        ``datasets/create/<prompt_type>.json`` (next to this module), the
        model's chat template is applied twice: once with the system message
        ``system_role + contrastive_pair[0] + "."`` and once with
        ``system_role + contrastive_pair[1] + "."``, using the variation as
        the user message. The two rendered prompts form one entry.

        Args:
            model_name (str): The name of the pre-trained model whose tokenizer
                (and chat template) is used.
            contrastive_pair (list): Two elements: the positive and the
                negative trait.
            system_role (str, optional): Prefix of the system message.
                Defaults to "Act as if you are extremely ".
            prompt_type (str, optional): Name of the prompt variations file
                (without ".json"). Defaults to "sentence-starters".
            num_sents (int, optional): Number of prompt variations to use.
                Defaults to 300.

        Returns:
            Dataset: A new instance of ``cls`` with the generated entries.

        Raises:
            ValueError: If ``contrastive_pair`` has fewer than two elements.
            FileNotFoundError: If the prompt variations file does not exist.
            json.JSONDecodeError: If the prompt variations file is not valid
                JSON.
        """
        # Validate and read local inputs before the (possibly remote)
        # tokenizer download so bad arguments fail fast.
        if len(contrastive_pair) < 2:
            raise ValueError(
                "contrastive_pair must contain a positive and a negative trait."
            )
        positive_trait, negative_trait = contrastive_pair[0], contrastive_pair[1]

        file_path = os.path.join(
            os.path.dirname(__file__), "datasets", "create", f"{prompt_type}.json"
        )
        with open(file_path, "r", encoding="utf-8") as file:
            variations = json.load(file)

        tokenizer = cls._load_tokenizer(model_name)

        # cls() (as in load_from_file/load_dataset) so subclasses get their type.
        dataset = cls()
        for variation in variations[:num_sents]:
            positive_prompt = cls._apply_chat_template(
                tokenizer, system_role, positive_trait, variation
            )
            negative_prompt = cls._apply_chat_template(
                tokenizer, system_role, negative_trait, variation
            )
            dataset.add_entry(positive_prompt, negative_prompt)
        return dataset

    @classmethod
    def load_from_file(cls, file_path: str) -> "Dataset":
        """
        Loads a dataset from a JSON file.

        Args:
            file_path (str): The path to the JSON file containing the dataset,
                a list of {"positive": ..., "negative": ...} objects.

        Returns:
            Dataset: A new instance of ``cls`` loaded from the file.

        Raises:
            ValueError: If an entry lacks a "positive" or "negative" key.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)
        dataset = cls()
        dataset.add_from_saved(data)
        return dataset

    @classmethod
    def load_dataset(
        cls, model_name: str, name: str, num_sents: int = 300
    ) -> "Dataset":
        """
        Loads a pre-saved corpus bundled with the package and re-applies the
        model's chat template to each entry.

        Each positive/negative text becomes the user message of a prompt with
        no system message. Only the first ``num_sents`` entries are used.

        Args:
            model_name (str): The model whose tokenizer/chat template is used.
            name (str): Corpus file name under ``datasets/load`` (no ".json").
            num_sents (int, optional): Number of entries to keep. Defaults to
                300.

        Returns:
            Dataset: A new instance of ``cls`` with the templated entries.

        Raises:
            FileNotFoundError: If the named corpus does not exist.
        """
        base_path = os.path.join(os.path.dirname(__file__), "datasets", "load")
        file_path = os.path.join(base_path, f"{name}.json")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Dataset '{name}' not found.")

        # Raw data: list of dicts with "positive" and "negative" keys.
        with open(file_path, "r", encoding="utf-8") as file:
            raw_entries = json.load(file)

        tokenizer = cls._load_tokenizer(model_name)

        processed_dataset = cls()
        for entry in raw_entries[:num_sents]:
            positive_transformed = cls._apply_chat_template(
                tokenizer, system_role="", content1="", content2=entry["positive"]
            )
            negative_transformed = cls._apply_chat_template(
                tokenizer, system_role="", content1="", content2=entry["negative"]
            )
            processed_dataset.add_entry(positive_transformed, negative_transformed)
        return processed_dataset

    def __str__(self) -> str:
        """
        Returns a string representation of the dataset for easy viewing.
        """
        return "\n".join(
            f"Positive: {entry.positive}\nNegative: {entry.negative}"
            for entry in self.entries
        )

    def __getitem__(self, index: int) -> "DatasetEntry":
        """
        Allows indexing into the dataset to retrieve a specific entry.

        Args:
            index (int): The index of the entry to retrieve.

        Returns:
            DatasetEntry: The dataset entry at the specified index.
        """
        return self.entries[index]

    def __len__(self) -> int:
        """
        Returns the number of entries in the dataset.

        Returns:
            int: The number of entries in the dataset.
        """
        return len(self.entries)