# Remove near-duplicate sentences whose similarity is greater than the threshold (e.g. 0.95) on the text column.
def deduplicate_datasets(dataset: DatasetDict, model: str, threshold: float) -> DatasetDict:
    """Drop near-duplicate rows from the ``train`` split by response similarity.

    Each ``response`` in ``dataset['train']`` is embedded with the given
    SentenceTransformer model. Rows are scanned in their original order; a
    row is kept unless an earlier *kept* row has cosine similarity
    ``>= threshold`` with it (greedy, first occurrence wins).

    Args:
        dataset: Dataset dict with a ``'train'`` split whose rows have a
            ``'response'`` field. Only the ``train`` split is returned.
        model: Model name or path passed to ``SentenceTransformer``.
        threshold: Cosine-similarity cutoff; pairs at or above it are
            treated as duplicates.

    Returns:
        A ``DatasetDict`` holding only the deduplicated ``train`` split, with
        the surviving rows in their original order.
    """
    sentence_model = SentenceTransformer(model)
    train = dataset["train"]
    outputs = [example["response"] for example in train]
    print("Converting text to embeddings...")
    embeddings = sentence_model.encode(outputs, show_progress_bar=True)
    print("Filtering out near-duplicates...")

    n = len(embeddings)
    # removed[j] is True once some earlier kept row was found too similar to j.
    removed = np.zeros(n, dtype=bool)
    to_keep = []
    for i in tqdm(range(n), desc="Filtering"):
        if removed[i]:
            continue
        to_keep.append(i)
        if i + 1 < n:  # the last row has nothing after it to compare against
            # Compare this kept row only against later rows, one row at a
            # time, so memory stays O(n) instead of building an n x n matrix.
            sims = cosine_similarity(embeddings[i : i + 1], embeddings[i + 1 :])[0]
            removed[i + 1 :] |= sims >= threshold

    # to_keep is built in ascending order, so select() preserves row order.
    return DatasetDict({"train": train.select(to_keep)})