MarkovTextGen/MarkovWordGen.java
2024-08-10 11:00:16 -05:00

155 lines
5.2 KiB
Java

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Scanner;
/**
 * Word-level Markov-chain text generator.
 *
 * <p>An instance is trained on whitespace-separated text with {@link #train(String)} and
 * then produces text from a seed with {@link #generate(String, int)}. Each distinct run of
 * {@code nGram} consecutive tokens (joined by single spaces) is a key that owns one
 * {@code WordNode} holding the tokens observed to follow it. The class also contains a
 * {@code main} method with a basic demo.
 */
public class MarkovWordGen {
    /** Marker token placed in training text to separate independent pieces of data. */
    private static final String STOP_TOKEN = "##STOP##";

    /**
     * Learned transitions, keyed by the space-joined n-gram. A LinkedHashMap keeps the
     * order in which keys were first seen, so printNodes() lists them in training order.
     */
    private final Map<String, WordNode> wordNodes;

    /** Number of tokens that make up one key (1 = unigram, 2 = bigram, ...). */
    private final int nGram;

    /**
     * Demo: trains a bigram model on "./Poe.txt" and prints generated text
     * seeded with "The thousand".
     */
    public static void main(String[] args) {
        String data = readFile("./Poe.txt");
        MarkovWordGen bot = new MarkovWordGen(2);
        bot.train(data);
        System.out.println(bot.generate("The thousand", 1000));
    }

    /**
     * Creates an untrained generator.
     *
     * @param n number of tokens per key; must be at least 1
     * @throws IllegalArgumentException if {@code n} is less than 1
     */
    public MarkovWordGen(int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be at least 1, got " + n);
        }
        wordNodes = new LinkedHashMap<>();
        nGram = n;
    }

    /**
     * Trains the model on {@code input}. May be called repeatedly; each call adds to what
     * has already been learned.
     *
     * <p>The input is split on whitespace. The token {@code ##STOP##} separates independent
     * pieces of training data (for example, individual stories): no key or successor ever
     * spans a separator, and the separator itself is never learned as a word.
     *
     * @param input the training text
     */
    public void train(String input) {
        ArrayList<String> segment = new ArrayList<>();
        for (String token : input.split("\\s+")) {
            if (token.isEmpty()) {
                // split() yields one empty leading token when the input starts with whitespace.
                continue;
            }
            if (token.equals(STOP_TOKEN)) {
                trainSegment(segment);
                segment.clear();
            } else {
                segment.add(token);
            }
        }
        trainSegment(segment); // Text after the last separator (or all of it, if none).
    }

    /** Records every (n-gram key, following token) pair found inside one separator-free segment. */
    private void trainSegment(ArrayList<String> segment) {
        for (int i = 0; i + nGram < segment.size(); i++) {
            String key = String.join(" ", segment.subList(i, i + nGram));
            String next = segment.get(i + nGram); // The token after the key value.
            WordNode node = wordNodes.get(key);
            if (node == null) {
                wordNodes.put(key, new WordNode(key, next));
            } else {
                node.add(next);
            }
        }
    }

    /**
     * Returns a String generated from the markov chain.
     *
     * <p>The result is {@code seed}, a space, and then each generated word followed by a
     * space. The last {@code nGram} whitespace-separated words of the seed form the initial
     * context. Generation ends early as soon as the current context is not a known key; if
     * the seed has fewer than {@code nGram} words, nothing is generated.
     *
     * @param seed   starting text, e.g. "The" for a unigram model or "The dog" for a bigram
     * @param tokens maximum number of words to generate
     * @return the seed followed by the generated words
     */
    public String generate(String seed, int tokens) {
        StringBuilder out = new StringBuilder(seed).append(' ');
        String[] seedWords = seed.trim().split("\\s+");
        if (seedWords.length < nGram) {
            return out.toString();
        }
        String[] context =
                Arrays.copyOfRange(seedWords, seedWords.length - nGram, seedWords.length);
        for (int i = 0; i < tokens; i++) {
            WordNode node = wordNodes.get(String.join(" ", context));
            if (node == null) {
                // The context can no longer change, so no later iteration could match either.
                break;
            }
            String newWord = node.generateWord(); // The node chooses the successor.
            // Slide the context window left by one and append the new word.
            System.arraycopy(context, 1, context, 0, context.length - 1);
            context[context.length - 1] = newWord;
            out.append(newWord).append(' ');
        }
        return out.toString();
    }

    /**
     * Simple printout of wordNodes in a readable format. Useful for debugging.
     * Each key is printed on its own line, followed by one line per recorded
     * word value together with the matching entry from getNumVals().
     */
    public void printNodes() {
        for (WordNode n : wordNodes.values()) {
            System.out.println(n.getKey() + " - ");
            for (int i = 0; i < n.getWordVals().size(); i++) {
                System.out.println(" " + n.getWordVals().get(i) + " -- " + n.getNumVals().get(i));
            }
        }
    }

    /**
     * Reads a text file into one String and returns it. Lines are joined with a single
     * space (a trailing space follows the last line).
     *
     * @param filePath path of the file to read
     * @return the file contents, or an empty string if the file does not exist (in which
     *         case a message is written to standard error)
     */
    public static String readFile(String filePath) {
        StringBuilder out = new StringBuilder();
        try (Scanner scn = new Scanner(new File(filePath))) {
            while (scn.hasNextLine()) {
                out.append(scn.nextLine()).append(' ');
            }
        } catch (FileNotFoundException e) {
            System.err.println("File not found.");
        }
        return out.toString();
    }
}